Skip to content

Commit 58e7545

Browse files
committed
[TOSA] Fix negate maxValue computation
getInput1Zp() returns an unsigned value which means in case of negative zero point value the max intermediate value computation currently goes wrong. Use getInput1ZpAttr() instead which returns an APInt and allows easy sign extension to int64_t.
1 parent 5566bfa commit 58e7545

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
146146
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
147147

148148
if (isa<IntegerType>(elementTy)) {
149-
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1Zp();
150-
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZp();
149+
auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
150+
auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
151151

152-
const int64_t inZp = inputZpAttr ? *inputZpAttr : 0;
153-
const int64_t outZp = outputZpAttr ? *outputZpAttr : 0;
152+
const int64_t inZp = inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
153+
const int64_t outZp = outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
154154

155155
if (!inZp && !outZp) {
156156
auto constant = rewriter.create<arith::ConstantOp>(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,12 +911,25 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
911911
// CHECK: linalg.yield [[TRUNC]]
912912
%2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
913913

914+
// CHECK: linalg.generic
915+
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
916+
// CHECK: [[C_128:%.+]] = arith.constant -128
917+
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
918+
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
919+
// CHECK: [[MIN:%.+]] = arith.constant -128
920+
// CHECK: [[MAX:%.+]] = arith.constant 127
921+
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
922+
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
923+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
924+
// CHECK: linalg.yield [[TRUNC]]
925+
%3 = tosa.negate %arg0 {input1_zp = -128 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
926+
914927
// CHECK: linalg.generic
915928
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
916929
// CHECK: [[ZERO:%.+]] = arith.constant 0
917930
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
918931
// CHECK: linalg.yield [[SUB]]
919-
%3 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
932+
%4 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
920933

921934
return
922935
}

0 commit comments

Comments
 (0)