@@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
452
452
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput ().getType ());
453
453
auto inputElementType = inputType.getElementType ();
454
454
455
- if (!inputType.hasStaticShape ()) {
456
- return failure ();
457
- }
458
-
459
455
if (isa<FloatType>(inputElementType)) {
460
456
// Unlike integer types, floating point types can represent infinity.
461
- auto minClamp =
457
+ const auto minClamp =
462
458
llvm::cast<mlir::FloatAttr>(op.getMinValAttr ()).getValue ();
463
- auto maxClamp =
459
+ const auto maxClamp =
464
460
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr ()).getValue ();
465
- bool isMin = minClamp.isNegInfinity ();
466
- bool isMax = maxClamp.isInfinity ();
461
+ const bool isMin = minClamp.isNegInfinity ();
462
+ const bool isMax = maxClamp.isInfinity ();
467
463
468
464
if (isMin && isMax) {
469
465
rewriter.replaceOp (op, input);
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
472
468
return failure ();
473
469
}
474
470
475
- if (inputElementType.isUnsignedInteger ()) {
476
- int64_t minClamp =
477
- llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getUInt ();
478
- int64_t maxClamp =
479
- llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getUInt ();
471
+ // i1 types are boolean in TOSA
472
+ const bool isBoolean = inputElementType.isInteger (1 );
473
+ if (inputElementType.isUnsignedInteger () || isBoolean) {
474
+ const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ())
475
+ .getValue ()
476
+ .getZExtValue ();
477
+ const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ())
478
+ .getValue ()
479
+ .getZExtValue ();
480
480
481
- int64_t intMin =
482
- APInt::getMinValue (inputElementType.getIntOrFloatBitWidth ())
483
- .getZExtValue ();
484
- int64_t intMax =
485
- APInt::getMaxValue (inputElementType.getIntOrFloatBitWidth ())
486
- .getZExtValue ();
481
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth ();
482
+ const int64_t intMin = APInt::getMinValue (bitWidth).getZExtValue ();
483
+ const int64_t intMax = APInt::getMaxValue (bitWidth).getZExtValue ();
487
484
488
485
if (minClamp <= intMin && maxClamp >= intMax) {
489
486
rewriter.replaceOp (op, input);
@@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
493
490
}
494
491
495
492
if (llvm::isa<IntegerType>(inputElementType)) {
496
- int64_t minClamp =
493
+ const int64_t minClamp =
497
494
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr ()).getInt ();
498
- int64_t maxClamp =
495
+ const int64_t maxClamp =
499
496
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr ()).getInt ();
500
497
501
- int64_t intMin =
502
- APInt::getSignedMinValue (inputElementType.getIntOrFloatBitWidth ())
503
- .getSExtValue ();
504
- int64_t intMax =
505
- APInt::getSignedMaxValue (inputElementType.getIntOrFloatBitWidth ())
506
- .getSExtValue ();
498
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth ();
499
+ const int64_t intMin = APInt::getSignedMinValue (bitWidth).getSExtValue ();
500
+ const int64_t intMax = APInt::getSignedMaxValue (bitWidth).getSExtValue ();
507
501
508
502
if (minClamp <= intMin && maxClamp >= intMax) {
509
503
rewriter.replaceOp (op, input);
0 commit comments