From f80660c69a562c6f9200718da9d960cd482fbffe Mon Sep 17 00:00:00 2001 From: swote Date: Wed, 30 Apr 2025 01:23:22 +0900 Subject: [PATCH] [MLIR][TOSA] Fix validation for unsigned integer types in RescaleOp This patch fixes a bug in the TOSA RescaleOp verifier that incorrectly rejects unsigned integer types (ui8, ui16), even though they are supported by the TOSA specification. The verifier now properly handles unsigned integer types when the corresponding input_unsigned or output_unsigned attribute is set to true. Added tests for ui8<->i8 and ui16<->i16 rescale operations. Fixes https://github.com/llvm/llvm-project/issues/135699 --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 30 ++++++++++++++---------- mlir/test/Dialect/Tosa/availability.mlir | 22 +++++++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b2e471f2bba93..980ef18b975f9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2111,24 +2111,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal, const int64_t &zp, const std::string &operand) { bool isInputZp = (operand == "Input"); - bool tensorUnsigned = - isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned(); + isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned(); StringRef tensorName = isInputZp ? "input" : "output"; - Type zpElemType = getElementTypeOrSelf(zpVal); if (zp != 0) { - if (!zpElemType.isInteger(8) && - !(zpElemType.isInteger(16) && tensorUnsigned)) { - return op.emitOpError() - << "expect " << tensorName << "_zp of 0, got " << zp; + bool validType = zpElemType.isInteger(8); + + if (tensorUnsigned && zpElemType.isInteger(8)) { + validType = true; } - if (zpElemType.isInteger(16) && tensorUnsigned && - zp != static_cast(32768)) { - return op.emitOpError() << "expect " << tensorName - << "_zp of 0 or 32768 for unsigned int16 " - << tensorName << ", got " << zp; + + if (zpElemType.isInteger(16) && tensorUnsigned) { + validType = true; + if (zp != 32768) { + return op.emitOpError() << "expect " << tensorName + << "_zp of 0 or 32768 for unsigned int16 " + << tensorName << ", got " << zp; + } + } + + if (!validType) { + return op.emitOpError() + << "expect " << tensorName << "_zp of 0, got " << zp; } } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 75126a11ac504..08d2bd30cf971 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -622,6 +622,28 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform> } +// ----- +// CHECK-LABEL: test_rescale +func.func @test_rescale_unsigned_i8(%arg0: tensor<13x21x3x!quant.uniform>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> { + %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + // CHECK: tosa.rescale profiles: [ [pro_int] ] + // CHECK: tosa.rescale extensions: [ [int16] ] + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: test_rescale +func.func @test_rescale_to_unsigned_i8(%arg0: tensor<13x21x3x!quant.uniform>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> { + %input_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8> + // CHECK: tosa.rescale profiles: [ [pro_int] ] + // CHECK: tosa.rescale extensions: [ [int16] ] + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + // ----- // CHECK-LABEL: test_const func.func @test_const(%arg0 : index) -> tensor<4xi32> {