Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1302,9 +1302,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();

// i1 types are boolean in TOSA
if (trunc) {
intVal = intVal.trunc(bitwidth);
// i1 types are boolean in TOSA
if (outETy.isInteger(1)) {
intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
} else {
intVal = intVal.trunc(bitwidth);
}
} else if (unsignIn || inIntType.isInteger(1)) {
intVal = intVal.zext(bitwidth);
} else {
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
%1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
return %1 : tensor<i32>
}

// -----

// CHECK-LABEL: @test_fold_i32_to_i1_cast
// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
// CHECK: return %[[OUT]] : tensor<i1>
func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
%0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
%1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}