Skip to content

Commit 1023f6c

Browse files
authored
[mlir][tosa] Support boolean types for clamp folder (#151653)
This PR fixes several bugs in `ClampIsNoOp` pattern. - static shape check is no need. - ensures i1 values are zero extended to support fold boolean types clamp. Fixes #130016.
1 parent f9088f1 commit 1023f6c

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
452452
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
453453
auto inputElementType = inputType.getElementType();
454454

455-
if (!inputType.hasStaticShape()) {
456-
return failure();
457-
}
458-
459455
if (isa<FloatType>(inputElementType)) {
460456
// Unlike integer types, floating point types can represent infinity.
461-
auto minClamp =
457+
const auto minClamp =
462458
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
463-
auto maxClamp =
459+
const auto maxClamp =
464460
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
465-
bool isMin = minClamp.isNegInfinity();
466-
bool isMax = maxClamp.isInfinity();
461+
const bool isMin = minClamp.isNegInfinity();
462+
const bool isMax = maxClamp.isInfinity();
467463

468464
if (isMin && isMax) {
469465
rewriter.replaceOp(op, input);
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
472468
return failure();
473469
}
474470

475-
if (inputElementType.isUnsignedInteger()) {
476-
int64_t minClamp =
477-
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
478-
int64_t maxClamp =
479-
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
471+
// i1 types are boolean in TOSA
472+
const bool isBoolean = inputElementType.isInteger(1);
473+
if (inputElementType.isUnsignedInteger() || isBoolean) {
474+
const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
475+
.getValue()
476+
.getZExtValue();
477+
const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
478+
.getValue()
479+
.getZExtValue();
480480

481-
int64_t intMin =
482-
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
483-
.getZExtValue();
484-
int64_t intMax =
485-
APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
486-
.getZExtValue();
481+
const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
482+
const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
483+
const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
487484

488485
if (minClamp <= intMin && maxClamp >= intMax) {
489486
rewriter.replaceOp(op, input);
@@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
493490
}
494491

495492
if (llvm::isa<IntegerType>(inputElementType)) {
496-
int64_t minClamp =
493+
const int64_t minClamp =
497494
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
498-
int64_t maxClamp =
495+
const int64_t maxClamp =
499496
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
500497

501-
int64_t intMin =
502-
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
503-
.getSExtValue();
504-
int64_t intMax =
505-
APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
506-
.getSExtValue();
498+
const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
499+
const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
500+
const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
507501

508502
if (minClamp <= intMin && maxClamp >= intMax) {
509503
rewriter.replaceOp(op, input);

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,26 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
241241

242242
// -----
243243

244+
// CHECK-LABEL: @clamp_boolean_is_noop
245+
func.func @clamp_boolean_is_noop(%arg0: tensor<4xi1>) -> tensor<4xi1> {
246+
// CHECK: return %arg0
247+
// CHECK-NOT: tosa.clamp
248+
%0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<4xi1>) -> tensor<4xi1>
249+
return %0 : tensor<4xi1>
250+
}
251+
252+
// -----
253+
254+
// CHECK-LABEL: @clamp_boolean_dynamic_is_noop
255+
func.func @clamp_boolean_dynamic_is_noop(%arg0: tensor<?xi1>) -> tensor<?xi1> {
256+
// CHECK: return %arg0
257+
// CHECK-NOT: tosa.clamp
258+
%0 = tosa.clamp %arg0 {min_val = false, max_val = true} : (tensor<?xi1>) -> tensor<?xi1>
259+
return %0 : tensor<?xi1>
260+
}
261+
262+
// -----
263+
244264
// CHECK-LABEL: @clamp_int8_is_noop
245265
func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
246266
// CHECK: return %arg0

0 commit comments

Comments
 (0)