@@ -658,33 +658,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
658658 }
659659};
660660
661+ // Attempts the following transformation:
662+ //
663+ // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
664+ // tensor X the following identity holds:
665+ //
666+ // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
667+ //
668+ // subject to the following valid NaN propagation semantics:
669+ // --------------------------------------------
670+ // | OUTER CLAMP | INNER CLAMP | RESULT MODE |
671+ // |-------------|--------------|-------------|
672+ // | PROPAGATE | PROPAGATE | PROPAGATE |
673+ // | PROPAGATE | IGNORE | IGNORE |
674+ // | IGNORE | PROPAGATE | INVALID |
675+ // | IGNORE | IGNORE | IGNORE |
676+ // |------------------------------------------|
677+
661678struct ClampClampOptimization : public OpRewritePattern <tosa::ClampOp> {
662679 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
663680
681+ // Helper structure to describe the range of a clamp operation.
682+ template <typename T>
683+ struct ClampRange {
684+ ClampRange (const T &start, const T &end) : start(start), end(end) {}
685+ T start;
686+ T end;
687+
688+ // Helper function to determine if two Clamp ranges intersect.
689+ bool intersects (const ClampRange<T> &otherRange) {
690+ return start < otherRange.end && otherRange.start < end;
691+ }
692+ };
693+
664694 LogicalResult matchAndRewrite (tosa::ClampOp op,
665695 PatternRewriter &rewriter) const override {
666- Value input = op. getInput ();
667-
668- Operation *definingOp = input. getDefiningOp ();
669- if (!definingOp )
696+ // Check the input to the CLAMP op is itself a CLAMP.
697+ auto clampOp =
698+ dyn_cast_if_present<tosa::ClampOp>(op. getInput (). getDefiningOp () );
699+ if (!clampOp )
670700 return failure ();
671701
672- if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
673- auto minFp = std::max (op.getMinFp (), clampOp.getMinFp ()).convertToFloat ();
674- auto maxFp = std::min (op.getMaxFp (), clampOp.getMaxFp ()).convertToFloat ();
702+ // Check we have a valid NaN propagation combination.
703+ const auto opNanMode = op.getNanMode ();
704+ const auto clampNanMode = clampOp.getNanMode ();
705+ if (opNanMode == " IGNORE" && clampNanMode == " PROPAGATE" )
706+ return failure ();
675707
676- auto minInt = std::max (op.getMinInt (), clampOp.getMinInt ());
677- auto maxInt = std::min (op.getMaxInt (), clampOp.getMaxInt ());
708+ // Check we have intersecting ranges.
709+ const auto opMinInt = op.getMinInt ();
710+ const auto opMaxInt = op.getMaxInt ();
711+ const auto clampOpMinInt = clampOp.getMinInt ();
712+ const auto clampOpMaxInt = clampOp.getMaxInt ();
713+ ClampRange<std::int64_t > opRangeIntRange (opMinInt, opMaxInt);
714+ ClampRange<std::int64_t > clampRangeIntRange (clampOpMinInt, clampOpMaxInt);
715+ if (!opRangeIntRange.intersects (clampRangeIntRange))
716+ return failure ();
678717
679- rewriter.replaceOpWithNewOp <ClampOp>(
680- op, {op->getLoc (), clampOp->getLoc ()}, op.getType (),
681- clampOp.getInput (), rewriter.getI64IntegerAttr (minInt),
682- rewriter.getI64IntegerAttr (maxInt), rewriter.getF32FloatAttr (minFp),
683- rewriter.getF32FloatAttr (maxFp));
684- return success ();
685- }
718+ const auto opMinFloat = op.getMinFp ();
719+ const auto opMaxFloat = op.getMaxFp ();
720+ const auto clampOpMinFloat = clampOp.getMinFp ();
721+ const auto clampOpMaxFloat = clampOp.getMaxFp ();
722+ ClampRange opRangeFloatRange (opMinFloat, opMaxFloat);
723+ ClampRange clampRangeFloatRange (clampOpMinFloat, clampOpMaxFloat);
724+ if (!opRangeFloatRange.intersects (clampRangeFloatRange))
725+ return failure ();
686726
687- return failure ();
727+ // Run the transformation.
728+ const auto minFp = std::max (opMinFloat, clampOpMinFloat).convertToFloat ();
729+ const auto maxFp = std::min (opMaxFloat, clampOpMaxFloat).convertToFloat ();
730+ const auto minInt = std::max (opMinInt, clampOpMinInt);
731+ const auto maxInt = std::min (opMaxInt, clampOpMaxInt);
732+ rewriter.replaceOpWithNewOp <tosa::ClampOp>(
733+ op, {op->getLoc (), clampOp->getLoc ()}, op.getType (), clampOp.getInput (),
734+ rewriter.getI64IntegerAttr (minInt), rewriter.getI64IntegerAttr (maxInt),
735+ rewriter.getF32FloatAttr (minFp), rewriter.getF32FloatAttr (maxFp),
736+ rewriter.getStringAttr ((opNanMode != clampNanMode) ? " IGNORE"
737+ : opNanMode));
738+ return success ();
688739 }
689740};
690741
0 commit comments