@@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
452452 auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
453453 auto inputElementType = inputType.getElementType ();
454454
455- if (!inputType.hasStaticShape ()) {
456- return failure ();
457- }
458-
459455 if (isa<FloatType>(inputElementType)) {
460456 // Unlike integer types, floating point types can represent infinity.
461- auto minClamp =
457+ const auto minClamp =
462458 llvm::cast<mlir::FloatAttr>(op.getMinValAttr ()).getValue ();
463- auto maxClamp =
459+ const auto maxClamp =
464460 llvm::cast<mlir::FloatAttr>(op.getMaxValAttr ()).getValue ();
465- bool isMin = minClamp.isNegInfinity ();
466- bool isMax = maxClamp.isInfinity ();
461+ const bool isMin = minClamp.isNegInfinity ();
462+ const bool isMax = maxClamp.isInfinity ();
467463
468464 if (isMin && isMax) {
469465 rewriter.replaceOp (op, input);
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
472468 return failure ();
473469 }
474470
475- if (inputElementType.isUnsignedInteger ()) {
476- int64_t minClamp =
477- llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getUInt ();
478- int64_t maxClamp =
479- llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getUInt ();
471+ // i1 types are boolean in TOSA
472+ const bool isBoolean = inputElementType.isInteger (1 );
473+ if (inputElementType.isUnsignedInteger () || isBoolean) {
474+ const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ())
475+ .getValue ()
476+ .getZExtValue ();
477+ const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ())
478+ .getValue ()
479+ .getZExtValue ();
480480
481- int64_t intMin =
482- APInt::getMinValue (inputElementType.getIntOrFloatBitWidth ())
483- .getZExtValue ();
484- int64_t intMax =
485- APInt::getMaxValue (inputElementType.getIntOrFloatBitWidth ())
486- .getZExtValue ();
481+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth ();
482+ const int64_t intMin = APInt::getMinValue (bitWidth).getZExtValue ();
483+ const int64_t intMax = APInt::getMaxValue (bitWidth).getZExtValue ();
487484
488485 if (minClamp <= intMin && maxClamp >= intMax) {
489486 rewriter.replaceOp (op, input);
@@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
493490 }
494491
495492 if (llvm::isa<IntegerType>(inputElementType)) {
496- int64_t minClamp =
493+ const int64_t minClamp =
497494 llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getInt ();
498- int64_t maxClamp =
495+ const int64_t maxClamp =
499496 llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getInt ();
500497
501- int64_t intMin =
502- APInt::getSignedMinValue (inputElementType.getIntOrFloatBitWidth ())
503- .getSExtValue ();
504- int64_t intMax =
505- APInt::getSignedMaxValue (inputElementType.getIntOrFloatBitWidth ())
506- .getSExtValue ();
498+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth ();
499+ const int64_t intMin = APInt::getSignedMinValue (bitWidth).getSExtValue ();
500+ const int64_t intMax = APInt::getSignedMaxValue (bitWidth).getSExtValue ();
507501
508502 if (minClamp <= intMin && maxClamp >= intMax) {
509503 rewriter.replaceOp (op, input);
0 commit comments