Skip to content

Commit bd47095

Browse files
authored
[mlir][tosa] Support boolean types for clamp folder
This PR fixes several bugs in `ClampIsNoOp` pattern. - static shape check is no need. - ensures i1 values are zero extended to support fold boolean types clamp.
1 parent 3e2fadf commit bd47095

File tree

1 file changed

+15
-21
lines changed

1 file changed

+15
-21
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -452,10 +452,6 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
452452
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
453453
auto inputElementType = inputType.getElementType();
454454

455-
if (!inputType.hasStaticShape()) {
456-
return failure();
457-
}
458-
459455
if (isa<FloatType>(inputElementType)) {
460456
// Unlike integer types, floating point types can represent infinity.
461457
auto minClamp =
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
472468
return failure();
473469
}
474470

475-
if (inputElementType.isUnsignedInteger()) {
476-
int64_t minClamp =
477-
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
478-
int64_t maxClamp =
479-
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
471+
// i1 types are boolean in TOSA
472+
const bool isBoolean = inputElementType.isInteger(1);
473+
if (inputElementType.isUnsignedInteger() || isBoolean) {
474+
int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
475+
.getValue()
476+
.getZExtValue();
477+
int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
478+
.getValue()
479+
.getZExtValue();
480480

481-
int64_t intMin =
482-
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
483-
.getZExtValue();
484-
int64_t intMax =
485-
APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
486-
.getZExtValue();
481+
unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
482+
int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
483+
int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
487484

488485
if (minClamp <= intMin && maxClamp >= intMax) {
489486
rewriter.replaceOp(op, input);
@@ -498,12 +495,9 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
498495
int64_t maxClamp =
499496
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
500497

501-
int64_t intMin =
502-
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
503-
.getSExtValue();
504-
int64_t intMax =
505-
APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
506-
.getSExtValue();
498+
unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
499+
int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
500+
int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
507501

508502
if (minClamp <= intMin && maxClamp >= intMax) {
509503
rewriter.replaceOp(op, input);

0 commit comments

Comments
 (0)