Skip to content

Commit 262afc8

Browse files
authored
[mlir][TosaToLinalg] RescaleConverter only support integer type (#114239)
This PR fixes a bug in the `RescaleConverter` that allows non-integer types, which leads to a crash. Fixes #61383.
1 parent d3daa3c commit 262afc8

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,9 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
11531153
return rewriter.notifyMatchFailure(
11541154
op, "tosa.rescale requires scale32 for double_round to be true");
11551155

1156+
if (!isa<IntegerType>(inputTy.getElementType()))
1157+
return rewriter.notifyMatchFailure(op, "only support integer type");
1158+
11561159
SmallVector<Value> dynDims;
11571160
for (int i = 0; i < outputTy.getRank(); i++) {
11581161
if (outputTy.isDynamicDim(i)) {

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,12 @@ func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x
3636
%real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
3737
return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
3838
}
39+
40+
// -----
41+
42+
// CHECK-LABEL: @rescale_unsupported_type
43+
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
44+
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
45+
%0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
46+
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
47+
}

0 commit comments

Comments
 (0)