diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 1ecb892a4ea92..bca77ba68fbd1 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -770,10 +770,20 @@ struct VectorLoadOpConverter final spirv::StorageClass storageClass = attr.getValue(); auto vectorType = loadOp.getVectorType(); - auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); - Value castedAccessChain = - rewriter.create(loc, vectorPtrType, accessChain); - rewriter.replaceOpWithNewOp(loadOp, vectorType, + // Use the converted vector type instead of original (single element vector + // would get converted to scalar). + auto spirvVectorType = typeConverter.convertType(vectorType); + auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); + + // For single element vectors, we don't need to bitcast the access chain to + // the original vector type. Both is going to be the same, a pointer + // to a scalar. + Value castedAccessChain = (vectorType.getNumElements() == 1) + ? accessChain + : rewriter.create( + loc, vectorPtrType, accessChain); + + rewriter.replaceOpWithNewOp(loadOp, spirvVectorType, castedAccessChain); return success(); @@ -806,8 +816,15 @@ struct VectorStoreOpConverter final spirv::StorageClass storageClass = attr.getValue(); auto vectorType = storeOp.getVectorType(); auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); - Value castedAccessChain = - rewriter.create(loc, vectorPtrType, accessChain); + + // For single element vectors, we don't need to bitcast the access chain to + // the original vector type. Both is going to be the same, a pointer + // to a scalar. + Value castedAccessChain = (vectorType.getNumElements() == 1) + ? accessChain + : rewriter.create( + loc, vectorPtrType, accessChain); + rewriter.replaceOpWithNewOp(storeOp, castedAccessChain, adaptor.getValueToStore()); diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 3f0bf1962e299..4701ac5d96009 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class return %0: vector<4xf32> } + +// CHECK-LABEL: @vector_load_single_elem +// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class>) +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32 +// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32 +// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32 +// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32 +// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 +// CHECK: %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32 +// CHECK: %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32> +// CHECK: return %[[R0]] : vector<1xf32> +func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class>) -> vector<1xf32> { + %idx = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class>, vector<1xf32> + return %0: vector<1xf32> +} + + // CHECK-LABEL: @vector_load_2d // CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class>) -> vector<4xf32> { // CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> @@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class> +// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32> +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32 +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32 +// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32 +// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32 +// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32 +// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr +// CHECK: spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32 +func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class>, %arg1 : vector<1xf32>) { + %idx = arith.constant 0 : index + vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class>, vector<1xf32> + return +} + // CHECK-LABEL: @vector_store_2d // CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class> // CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>