Skip to content

Commit a318c50

Browse files
authored
[mlir][tosa] Remove NegateOp to SubOp and 48-bit promotion in TosaToLinalg (#170622)
This patch is motivated by the failure of the TOSA conformance test negate_32x45x49_i16_full. The TosaToLinalg pass has an optimization that lowers TOSA Negate to a plain Sub (0 - x) when both zero points are zero. However, when the input is the minimum negative value of the element type, the negated result is not representable in that type, so the subtraction wraps around and produces a wrong result. With the optimization removed, the zp = 0 case goes through the same path as the other cases: the operand is promoted to a wider intermediate type, which avoids the wraparound. Before this patch, the promotion could widen int32 to int48. The TOSA specification for negate does not mention support for int48. Should we consider removing the promotion to int48 to stay aligned with the TOSA spec?
1 parent 19e1011 commit a318c50

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
200200
return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
201201

202202
if (isa<IntegerType>(elementTy)) {
203-
if (hasInZp && hasOutZp && !inZp && !outZp) {
204-
auto constant = arith::ConstantOp::create(
205-
rewriter, loc, IntegerAttr::get(elementTy, 0));
206-
return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
207-
args[0]);
208-
}
209-
210203
Value zpAddValue;
211204
Type intermediateType;
212205
// Compute the maximum value that can occur in the intermediate buffer.
@@ -221,14 +214,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
221214
std::abs(zpAdd) + 1;
222215

223216
// Convert that maximum value into the maximum bitwidth needed to
224-
// represent it. We assume 48-bit numbers may be supported further in
225-
// the pipeline.
217+
// represent it.
226218
if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
227219
intermediateBitWidth = 16;
228220
} else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
229221
intermediateBitWidth = 32;
230-
} else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
231-
intermediateBitWidth = 48;
232222
}
233223

234224
intermediateType = rewriter.getIntegerType(intermediateBitWidth);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,13 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
666666
// CHECK: linalg.generic
667667
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
668668
// CHECK: [[ZERO:%.+]] = arith.constant 0
669-
// CHECK: arith.subi [[ZERO]], %[[ARG1]]
669+
// CHECK: [[EXT:%.+]] = arith.extsi %[[IN:.*]] : i32 to i64
670+
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
671+
// CHECK: [[MIN:%.+]] = arith.constant -2147483648
672+
// CHECK: [[MAX:%.+]] = arith.constant 2147483647
673+
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
674+
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
675+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
670676
%in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
671677
%out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
672678
%5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -889,8 +895,13 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
889895
// CHECK: linalg.generic
890896
// CHECK: ^bb0(%[[BBARG0:.+]]: i8,
891897
// CHECK: [[ZERO:%.+]] = arith.constant 0
892-
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]],
893-
// CHECK: linalg.yield [[SUB]]
898+
// CHECK: [[EXT:%.+]] = arith.extsi %[[IN:.*]] : i8 to i16
899+
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
900+
// CHECK: [[MIN:%.+]] = arith.constant -128
901+
// CHECK: [[MAX:%.+]] = arith.constant 127
902+
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
903+
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
904+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
894905
%in_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
895906
%out_zp4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
896907
%4 = tosa.negate %arg0, %in_zp4, %out_zp4 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
@@ -2610,7 +2621,13 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
26102621
// CHECK: linalg.generic
26112622
// CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
26122623
// CHECK: [[ZERO:%.+]] = arith.constant 0
2613-
// CHECK: arith.subi [[ZERO]], %[[ARG1]]
2624+
// CHECK: [[EXT:%.+]] = arith.extsi %[[IN:.*]] : i32 to i64
2625+
// CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], [[EXT]]
2626+
// CHECK: [[MIN:%.+]] = arith.constant -2147483648
2627+
// CHECK: [[MAX:%.+]] = arith.constant 2147483647
2628+
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
2629+
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
2630+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
26142631
%in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
26152632
%out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
26162633
%5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>

0 commit comments

Comments
 (0)