diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 750ce85049409..06b19335f2aed 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -793,6 +793,9 @@ struct VectorLoadOpConverter final // Use the converted vector type instead of original (single element vector // would get converted to scalar). auto spirvVectorType = typeConverter.convertType(vectorType); + if (!spirvVectorType) + return rewriter.notifyMatchFailure(loadOp, "unsupported vector type"); + auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); // For single element vectors, we don't need to bitcast the access chain to diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 44941f0d0ac5d..f43a41a0af2f4 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -1161,3 +1161,15 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class>) -> vector<4xf32> { + %idx = arith.constant 0 : index + %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class>, vector<4xf32> + return %0: vector<4xf32> +}