Skip to content

Commit 7a8bff4

Browse files
authored
[mlir][tosa-to-linalg] fix arithmetic_right_shift conversion with round (#159930)
Fixed: #154259 According to TOSA spec, `tosa.arithmetic_right_shift` should handle round. ``` if (round == true && static_cast<int32_t>(value2) > 0 && (apply_arith_rshift<in_out_t>(value1, apply_sub_s<in_out_t>(value2, 1)) & 1 != 0)) { result = result + 1; } ``` The original conversion is the similar as definition, and will convert to pseudo code ```c++ result = (value1 >> value2) + ( (i1)(value2 > 0) & (i1)((value1 >> (value2 - 1)) & 1) ) ``` But when value2 is 0,`value1 >> (value2 - 1)` will produce poison value because performing arithmetic right shift on a negative number. Then the poison value propagate to the final result. This PR wants to change the conversion to `arith.select` to stop poison propagation. ```c++ result = (value1 >> value2) + (value2 > 0) ? (i1)((value1 >> (value2 - 1)) & 1) : (i1)(0) ```
1 parent d08e445 commit 7a8bff4

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
305305
IntegerAttr::get(elementTy, 1));
306306
auto zero = arith::ConstantOp::create(rewriter, loc,
307307
IntegerAttr::get(elementTy, 0));
308+
auto i1zero =
309+
arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
308310
auto i1one =
309311
arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
310312

@@ -322,9 +324,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
322324
ArrayRef<NamedAttribute>());
323325
auto isInputOdd =
324326
arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
325-
326-
auto shouldRound = arith::AndIOp::create(
327-
rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
327+
// shifted, truncated, isInputOdd can be poison when input2 is 0.
328+
auto shouldRound = arith::SelectOp::create(
329+
rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
328330
auto extended =
329331
arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
330332
return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,13 +698,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
698698
// CHECK: linalg.generic
699699
// CHECK: arith.constant 1
700700
// CHECK: arith.constant 0
701+
// CHECK: arith.constant false
701702
// CHECK: arith.constant true
702703
// CHECK: arith.cmpi
703704
// CHECK: arith.subi
704705
// CHECK: arith.shrsi
705706
// CHECK: arith.trunci
706707
// CHECK: and
707-
// CHECK: and
708+
// CHECK: arith.select
708709
// CHECK: arith.extui
709710
// CHECK: arith.addi
710711
%12 = tosa.arithmetic_right_shift %arg0, %arg0 {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

0 commit comments

Comments
 (0)