@@ -452,10 +452,6 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
452452 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
453453 auto inputElementType = inputType.getElementType ();
454454
455- if (!inputType.hasStaticShape ()) {
456- return failure ();
457- }
458-
459455 if (isa<FloatType>(inputElementType)) {
460456 // Unlike integer types, floating point types can represent infinity.
461457 auto minClamp =
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
472468 return failure ();
473469 }
474470
475- if (inputElementType.isUnsignedInteger ()) {
476- int64_t minClamp =
477- llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getUInt ();
478- int64_t maxClamp =
479- llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getUInt ();
471+ // i1 types are boolean in TOSA
472+ const bool isBoolean = inputElementType.isInteger (1 );
473+ if (inputElementType.isUnsignedInteger () || isBoolean) {
474+ int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ())
475+ .getValue ()
476+ .getZExtValue ();
477+ int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ())
478+ .getValue ()
479+ .getZExtValue ();
480480
481- int64_t intMin =
482- APInt::getMinValue (inputElementType.getIntOrFloatBitWidth ())
483- .getZExtValue ();
484- int64_t intMax =
485- APInt::getMaxValue (inputElementType.getIntOrFloatBitWidth ())
486- .getZExtValue ();
481+ unsigned bitWidth = inputElementType.getIntOrFloatBitWidth ();
482+ int64_t intMin = APInt::getMinValue (bitWidth).getZExtValue ();
483+ int64_t intMax = APInt::getMaxValue (bitWidth).getZExtValue ();
487484
488485 if (minClamp <= intMin && maxClamp >= intMax) {
489486 rewriter.replaceOp (op, input);
@@ -498,12 +495,9 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
498495 int64_t maxClamp =
499496 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getInt ();
500497
501- int64_t intMin =
502- APInt::getSignedMinValue (inputElementType.getIntOrFloatBitWidth ())
503- .getSExtValue ();
504- int64_t intMax =
505- APInt::getSignedMaxValue (inputElementType.getIntOrFloatBitWidth ())
506- .getSExtValue ();
498+ unsigned bitWidth = inputElementType.getIntOrFloatBitWidth ();
499+ int64_t intMin = APInt::getSignedMinValue (bitWidth).getSExtValue ();
500+ int64_t intMax = APInt::getSignedMaxValue (bitWidth).getSExtValue ();
507501
508502 if (minClamp <= intMin && maxClamp >= intMax) {
509503 rewriter.replaceOp (op, input);
0 commit comments