Skip to content

Commit f4b9839

Browse files
authored
[mlir][TensorToSPIRV] Add type check for tensor.extract in TensorToSPIRV (#107110)
This patch add a type check for `tensor.extract` in TensorToSPIRV. Only convert `tensor.extract` with supported element type. Fix #74466.
1 parent 812c96e commit f4b9839

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class TensorExtractPattern final
4545
ConversionPatternRewriter &rewriter) const override {
4646
auto tensorType = cast<RankedTensorType>(extractOp.getTensor().getType());
4747

48+
if (!isa<spirv::ScalarType>(tensorType.getElementType()))
49+
return rewriter.notifyMatchFailure(extractOp, "unsupported type");
4850
if (!tensorType.hasStaticShape())
4951
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
5052

mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,24 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
2929

3030
// -----
3131

32+
// CHECK-LABEL: test_spirv_unsupported_type_index
33+
func.func @test_spirv_unsupported_type_index(%a : index) {
34+
%cst = arith.constant dense<[1, 2]> : tensor<2xindex>
35+
// CHECK: tensor.extract
36+
%extract = tensor.extract %cst[%a] : tensor<2xindex>
37+
return
38+
}
39+
40+
// CHECK-LABEL: test_spirv_unsupported_type_i128
41+
func.func @test_spirv_unsupported_type_i128(%a : index) {
42+
%cst = arith.constant dense<[1, 2]> : tensor<2xi128>
43+
// CHECK: tensor.extract
44+
%extract = tensor.extract %cst[%a] : tensor<2xi128>
45+
return
46+
}
47+
48+
// -----
49+
3250
//===----------------------------------------------------------------------===//
3351
// Type conversion
3452
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)