Skip to content

Commit 020856e

Browse files
authored
[mlir][vector][spirv] Fix a crash in VectorLoadOpConverter (#149964)
This PR adds null check for `spirvVectorType` in VectorLoadOpConverter to prevent a crash. Fixes #149956.
1 parent b3e016e commit 020856e

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,9 @@ struct VectorLoadOpConverter final
793793
// Use the converted vector type instead of original (single element vector
794794
// would get converted to scalar).
795795
auto spirvVectorType = typeConverter.convertType(vectorType);
796+
if (!spirvVectorType)
797+
return rewriter.notifyMatchFailure(loadOp, "unsupported vector type");
798+
796799
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
797800

798801
// For single element vectors, we don't need to bitcast the access chain to

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,3 +1161,15 @@ func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageB
11611161
}
11621162

11631163
} // end module
1164+
1165+
// -----
1166+
1167+
// Ensure the case without module attributes not crash.
1168+
1169+
// CHECK-LABEL: @vector_load
1170+
// CHECK: vector.load
1171+
func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
1172+
%idx = arith.constant 0 : index
1173+
%0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
1174+
return %0: vector<4xf32>
1175+
}

0 commit comments

Comments
 (0)