diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b2e471f2bba93..980ef18b975f9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2111,24 +2111,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal, const int64_t &zp, const std::string &operand) { bool isInputZp = (operand == "Input"); - bool tensorUnsigned = - isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned(); + isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned(); StringRef tensorName = isInputZp ? "input" : "output"; - Type zpElemType = getElementTypeOrSelf(zpVal); if (zp != 0) { - if (!zpElemType.isInteger(8) && - !(zpElemType.isInteger(16) && tensorUnsigned)) { - return op.emitOpError() - << "expect " << tensorName << "_zp of 0, got " << zp; + bool validType = zpElemType.isInteger(8); + + if (tensorUnsigned && zpElemType.isInteger(8)) { + validType = true; } - if (zpElemType.isInteger(16) && tensorUnsigned && - zp != static_cast(32768)) { - return op.emitOpError() << "expect " << tensorName - << "_zp of 0 or 32768 for unsigned int16 " - << tensorName << ", got " << zp; + + if (zpElemType.isInteger(16) && tensorUnsigned) { + validType = true; + if (zp != 32768) { + return op.emitOpError() << "expect " << tensorName + << "_zp of 0 or 32768 for unsigned int16 " + << tensorName << ", got " << zp; + } + } + + if (!validType) { + return op.emitOpError() + << "expect " << tensorName << "_zp of 0, got " << zp; } } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 75126a11ac504..08d2bd30cf971 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -622,6 +622,28 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform> } +// ----- +// CHECK-LABEL: test_rescale +func.func @test_rescale_unsigned_i8(%arg0: tensor<13x21x3x!quant.uniform>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> { + %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + // CHECK: tosa.rescale profiles: [ [pro_int] ] + // CHECK: tosa.rescale extensions: [ [int16] ] + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: test_rescale +func.func @test_rescale_to_unsigned_i8(%arg0: tensor<13x21x3x!quant.uniform>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> { + %input_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + %output_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8> + // CHECK: tosa.rescale profiles: [ [pro_int] ] + // CHECK: tosa.rescale extensions: [ [int16] ] + %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + // ----- // CHECK-LABEL: test_const func.func @test_const(%arg0 : index) -> tensor<4xi32> {