diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 251c48859d2de..5291f95d37144 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1153,6 +1153,9 @@ class RescaleConverter : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "tosa.rescale requires scale32 for double_round to be true"); + if (!isa(inputTy.getElementType())) + return rewriter.notifyMatchFailure(op, "only support integer type"); + SmallVector dynDims; for (int i = 0; i < outputTy.getRank(); i++) { if (outputTy.isDynamicDim(i)) { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir index b78577275a52a..ea1b79cbd9507 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir @@ -36,3 +36,12 @@ func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32> } + +// ----- + +// CHECK-LABEL: @rescale_unsupported_type +func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { + // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}} + %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +}