From 58e7545e40e9bd1b10849e5172687fa8e12c3e55 Mon Sep 17 00:00:00 2001 From: Thomas Preud'homme Date: Fri, 7 Feb 2025 19:00:04 +0000 Subject: [PATCH 1/2] [TOSA] Fix negate maxValue computation getInput1Zp() returns an unsigned value which means in case of negative zero point value the max intermediate value computation currently goes wrong. Use getInput1ZpAttr() instead which returns an APInt and allows easy sign extension to int64_t. --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 8 ++++---- .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index e4f055ea2f5c4..de6a156ed3f57 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -146,11 +146,11 @@ static Value createLinalgBodyCalculationForElementwiseOp( return rewriter.create(loc, resultTypes, args); if (isa(elementTy)) { - auto inputZpAttr = cast(op).getInput1Zp(); - auto outputZpAttr = cast(op).getOutputZp(); + auto inputZpAttr = cast(op).getInput1ZpAttr(); + auto outputZpAttr = cast(op).getOutputZpAttr(); - const int64_t inZp = inputZpAttr ? *inputZpAttr : 0; - const int64_t outZp = outputZpAttr ? *outputZpAttr : 0; + const int64_t inZp = inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0; + const int64_t outZp = outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0; if (!inZp && !outZp) { auto constant = rewriter.create( diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 3031434e6d4ba..d8ba28a3ce887 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -911,12 +911,25 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () { // CHECK: linalg.yield [[TRUNC]] %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8> + // CHECK: linalg.generic + // CHECK: ^bb0(%[[BBARG0:.+]]: i8, + // CHECK: [[C_128:%.+]] = arith.constant -128 + // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16 + // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]] + // CHECK: [[MIN:%.+]] = arith.constant -128 + // CHECK: [[MAX:%.+]] = arith.constant 127 + // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]] + // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]] + // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]] + // CHECK: linalg.yield [[TRUNC]] + %3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8> + // CHECK: linalg.generic // CHECK: ^bb0(%[[BBARG0:.+]]: i8, // CHECK: [[ZERO:%.+]] = arith.constant 0 // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], // CHECK: linalg.yield [[SUB]] - %3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant} : (tensor<1xi8>) -> tensor<1xi8> + %4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant} : (tensor<1xi8>) -> tensor<1xi8> return } From 6cba9c1935a5e42380cef7fd0b914bb30c8b0de2 Mon Sep 17 00:00:00 2001 From: Thomas Preud'homme Date: Fri, 7 Feb 2025 20:55:55 +0000 Subject: [PATCH 2/2] Fix code formatting. --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index de6a156ed3f57..0246d9019368a 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -149,8 +149,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( auto inputZpAttr = cast(op).getInput1ZpAttr(); auto outputZpAttr = cast(op).getOutputZpAttr(); - const int64_t inZp = inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0; - const int64_t outZp = outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0; + const int64_t inZp = + inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0; + const int64_t outZp = + outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0; if (!inZp && !outZp) { auto constant = rewriter.create(