Skip to content

Commit 358fc77

Browse files
RoboTuxIcohedron
authored andcommitted
[TOSA] Fix negate maxValue computation (llvm#126295)
getInput1Zp() returns an unsigned value which means in case of negative zero point value the max intermediate value computation currently goes wrong. Use getInput1ZpAttr() instead which returns an APInt and allows easy sign extension to int64_t.
1 parent b284abf commit 358fc77

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,13 @@ static Value createLinalgBodyCalculationForElementwiseOp(
146146
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
147147

148148
if (isa<IntegerType>(elementTy)) {
149-
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1Zp();
150-
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZp();
149+
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
150+
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
151151

152-
const int64_t inZp = inputZpAttr ? *inputZpAttr : 0;
153-
const int64_t outZp = outputZpAttr ? *outputZpAttr : 0;
152+
const int64_t inZp =
153+
inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
154+
const int64_t outZp =
155+
outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
154156

155157
if (!inZp && !outZp) {
156158
auto constant = rewriter.create<arith::ConstantOp>(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,12 +911,25 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
911911
// CHECK: linalg.yield [[TRUNC]]
912912
%2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
913913

914+
// CHECK: linalg.generic
915+
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
916+
// CHECK: [[C_128:%.+]] = arith.constant -128
917+
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
918+
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
919+
// CHECK: [[MIN:%.+]] = arith.constant -128
920+
// CHECK: [[MAX:%.+]] = arith.constant 127
921+
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
922+
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
923+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
924+
// CHECK: linalg.yield [[TRUNC]]
925+
%3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
926+
914927
// CHECK: linalg.generic
915928
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
916929
// CHECK: [[ZERO:%.+]] = arith.constant 0
917930
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
918931
// CHECK: linalg.yield [[SUB]]
919-
%3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
932+
%4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
920933

921934
return
922935
}

0 commit comments

Comments
 (0)