Skip to content

Commit 0a8ddd3

Browse files
authored
[mlir][tosa] Interpret boolean values correctly (#149312)
Previously the `ClampOp::verify` would sign extend boolean values, leading "true" to be casted to a value of -1 instead of 1. This PR ensures i1 values are zero extended, since i1 is used as a boolean value in TOSA.
1 parent 58be622 commit 0a8ddd3

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -945,10 +945,11 @@ LogicalResult tosa::ClampOp::verify() {
945945
return emitOpError("min/max attributes types are incompatible with "
946946
"input/output element types.");
947947

948-
const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
948+
const bool isUnsigned = inputETy.isUnsignedInteger();
949+
const bool isBoolean = inputETy.isInteger(1);
949950
const APInt minVal = intMinValAttr.getValue();
950951
const APInt maxVal = intMaxValAttr.getValue();
951-
if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
952+
if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
952953
return emitOpError("expected min_val <= max_val, got min_val=")
953954
<< minValAttr << ", max_val=" << maxValAttr;
954955
} else {

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,15 @@ func.func @test_mismatch_in_out_shape_clamp(%arg0: tensor<13x21x3xf32>) -> tenso
750750

751751
// -----
752752

753+
// CHECK-LABEL: test_unsupported_boolean_type_clamp
754+
func.func @test_unsupported_boolean_type_clamp(%arg0: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> {
755+
// expected-error@+1 {{'tosa.clamp' op illegal: operation operand/result data types did not align with any profile or extension, got (i1,i1), did you mean (i8,i8)?}}
756+
%0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
757+
return %0 : tensor<13x21x3xi1>
758+
}
759+
760+
// -----
761+
753762
// CHECK-LABEL: test_mismatch_in_out_data_type_erf
754763
func.func @test_mismatch_in_out_data_type_erf(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf16> {
755764
// expected-error@+1 {{'tosa.erf' op requires the same element type for all operands and results}}

0 commit comments

Comments
 (0)