Skip to content

Conversation

@CoTinker
Copy link
Contributor

According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., out = (in != 0) ? true : false. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero. Fixes #150302.

CoTinker added 2 commits July 24, 2025 12:24
According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., `out = (in != 0) ? true : false`. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero.
@llvmbot
Copy link
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., out = (in != 0) ? true : false. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero. Fixes #150302.


Full diff: https://github.com/llvm/llvm-project/pull/150370.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+11)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 606626dfe4d2c..080955bf94761 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1302,9 +1302,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
       auto intVal = operand.getSplatValue<APInt>();
       auto bitwidth = outETy.getIntOrFloatBitWidth();
 
+      // i1 types are boolean in TOSA
       if (trunc) {
-        intVal = intVal.trunc(bitwidth);
-        // i1 types are boolean in TOSA
+        if (outETy.isInteger(1)) {
+          intVal = intVal.isZero() ? APInt(bitwidth, 0) : APInt(bitwidth, 1);
+        } else {
+          intVal = intVal.trunc(bitwidth);
+        }
       } else if (unsignIn || inIntType.isInteger(1)) {
         intVal = intVal.zext(bitwidth);
       } else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11c8d54fda055..6b55442a82a0a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
   %1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
   return %1 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fold_i32_to_i1_cast
+// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: return %[[OUT]] : tensor<i1>
+func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
+  %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
+  %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
+  return %1 : tensor<i1>
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes

According to the TOSA spec, casting to boolean should produce true if the input is non-zero, and false otherwise — i.e., out = (in != 0) ? true : false. Previous behavior incorrectly relied on truncation, which could yield incorrect results for non-zero values whose least significant bit is zero. Fixes #150302.


Full diff: https://github.com/llvm/llvm-project/pull/150370.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+6-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+11)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 606626dfe4d2c..080955bf94761 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1302,9 +1302,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
       auto intVal = operand.getSplatValue<APInt>();
       auto bitwidth = outETy.getIntOrFloatBitWidth();
 
+      // i1 types are boolean in TOSA
       if (trunc) {
-        intVal = intVal.trunc(bitwidth);
-        // i1 types are boolean in TOSA
+        if (outETy.isInteger(1)) {
+          intVal = intVal.isZero() ? APInt(bitwidth, 0) : APInt(bitwidth, 1);
+        } else {
+          intVal = intVal.trunc(bitwidth);
+        }
       } else if (unsignIn || inIntType.isInteger(1)) {
         intVal = intVal.zext(bitwidth);
       } else {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11c8d54fda055..6b55442a82a0a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
   %1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
   return %1 : tensor<i32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_fold_i32_to_i1_cast
+// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: return %[[OUT]] : tensor<i1>
+func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
+  %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
+  %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
+  return %1 : tensor<i1>
+}

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick fix @CoTinker! Had a nitpick, otherwise LGTM!

@CoTinker CoTinker merged commit c2c881f into llvm:main Jul 24, 2025
9 checks passed
@CoTinker CoTinker deleted the cast_i1_folder branch July 24, 2025 14:48
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
According to the TOSA spec, casting to boolean should produce true if
the input is non-zero, and false otherwise — i.e., `out = (in != 0) ?
true : false`. Previous behavior incorrectly relied on truncation, which
could yield incorrect results for non-zero values whose least
significant bit is zero. Fixes llvm#150302.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[MLIR][TOSA] Folding cast to bool gives wrong value

3 participants