-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa] Fix integer-to-boolean cast folder #150370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., `out = (in != 0) ? true : false`. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero.
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesAccording to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., Full diff: https://github.com/llvm/llvm-project/pull/150370.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 606626dfe4d2c..080955bf94761 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1302,9 +1302,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();
+ // i1 types are boolean in TOSA
if (trunc) {
- intVal = intVal.trunc(bitwidth);
- // i1 types are boolean in TOSA
+ if (outETy.isInteger(1)) {
+ intVal = intVal.isZero() ? APInt(bitwidth, 0) : APInt(bitwidth, 1);
+ } else {
+ intVal = intVal.trunc(bitwidth);
+ }
} else if (unsignIn || inIntType.isInteger(1)) {
intVal = intVal.zext(bitwidth);
} else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11c8d54fda055..6b55442a82a0a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
%1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @test_fold_i32_to_i1_cast
+// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: return %[[OUT]] : tensor<i1>
+func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
+ %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
+ return %1 : tensor<i1>
+}
|
|
@llvm/pr-subscribers-mlir-tosa Author: Longsheng Mou (CoTinker) ChangesAccording to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., Full diff: https://github.com/llvm/llvm-project/pull/150370.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 606626dfe4d2c..080955bf94761 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1302,9 +1302,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();
+ // i1 types are boolean in TOSA
if (trunc) {
- intVal = intVal.trunc(bitwidth);
- // i1 types are boolean in TOSA
+ if (outETy.isInteger(1)) {
+ intVal = intVal.isZero() ? APInt(bitwidth, 0) : APInt(bitwidth, 1);
+ } else {
+ intVal = intVal.trunc(bitwidth);
+ }
} else if (unsignIn || inIntType.isInteger(1)) {
intVal = intVal.zext(bitwidth);
} else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11c8d54fda055..6b55442a82a0a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
%1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @test_fold_i32_to_i1_cast
+// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: return %[[OUT]] : tensor<i1>
+func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
+ %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
+ return %1 : tensor<i1>
+}
|
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick fix @CoTinker! Had a nitpick, otherwise LGTM!
According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., `out = (in != 0) ? true : false`. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero. Fixes llvm#150302.
According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e.,
out = (in != 0) ? true : false. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero. Fixes #150302.