Skip to content

Commit e08226b

Browse files
lhutton1aokblast
authored andcommitted
[mlir][tosa] Fix argmax folder when output type is i64 (llvm#163583)
Previously the following IR: ``` tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64> ``` Would result in a crash with the assertion: ``` expected dense element bit width 64 to match data size 32 for type i64 ``` This commit ensures that zero is constructed with the correct bitwidth while folding, therefore fixing the crash.
1 parent 8d67157 commit e08226b

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
10011001
!outputTy.hasStaticShape())
10021002
return {};
10031003

1004-
if (inputTy.getDimSize(getAxis()) == 1)
1005-
return DenseElementsAttr::get(outputTy, 0);
1004+
const Type outputElementTy = getElementTypeOrSelf(outputTy);
1005+
if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
1006+
const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
1007+
const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
1008+
return DenseElementsAttr::get(outputTy, zero);
1009+
}
10061010

10071011
return {};
10081012
}

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
99

1010
// -----
1111

12+
// CHECK-LABEL: @test_argmax_fold_i64_index
13+
func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> {
14+
// CHECK: "tosa.const"() <{values = dense<0> : tensor<i64>}> : () -> tensor<i64>
15+
%0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>
16+
return %0 : tensor<i64>
17+
}
18+
19+
// -----
20+
1221
// CHECK-LABEL: @pad_wh_avg_pool2d_fold
1322
func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
1423
// CHECK-NOT: tosa.pad

0 commit comments

Comments
 (0)