@@ -339,33 +339,84 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
339339 }
340340};
341341
342+ // Attempts the following transformation:
343+ //
344+ // For integers a, b, a', and b' such that [a, b] ∩ [a', b'] ≠ ∅ and input
345+ // tensor X the following identity holds:
346+ //
347+ // CLAMP(CLAMP(X, a, b), a', b') = CLAMP(X, max(a, a'), min(b, b'))
348+ //
349+ // subject to the following valid NaN propagation semantics:
350+ // --------------------------------------------
351+ // | OUTER CLAMP | INNER CLAMP | RESULT MODE |
352+ // |-------------|--------------|-------------|
353+ // | PROPAGATE | PROPAGATE | PROPAGATE |
354+ // | PROPAGATE | IGNORE | IGNORE |
355+ // | IGNORE | PROPAGATE | INVALID |
356+ // | IGNORE | IGNORE | IGNORE |
357+ // |------------------------------------------|
358+
342359struct ClampClampOptimization : public OpRewritePattern <tosa::ClampOp> {
343360 using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
344361
362+ // Helper structure to describe the range of a clamp operation.
363+ template <typename T>
364+ struct ClampRange {
365+ ClampRange (const T &start, const T &end) : start(start), end(end) {}
366+ T start;
367+ T end;
368+
369+ // Helper function to determine if two Clamp ranges intersect.
370+ bool intersects (const ClampRange<T> &otherRange) {
371+ return start < otherRange.end && otherRange.start < end;
372+ }
373+ };
374+
345375 LogicalResult matchAndRewrite (tosa::ClampOp op,
346376 PatternRewriter &rewriter) const override {
347- Value input = op. getInput ();
348-
349- Operation *definingOp = input. getDefiningOp ();
350- if (!definingOp )
377+ // Check the input to the CLAMP op is itself a CLAMP.
378+ auto clampOp =
379+ dyn_cast_if_present<tosa::ClampOp>(op. getInput (). getDefiningOp () );
380+ if (!clampOp )
351381 return failure ();
352382
353- if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
354- auto minFp = std::max (op.getMinFp (), clampOp.getMinFp ()).convertToFloat ();
355- auto maxFp = std::min (op.getMaxFp (), clampOp.getMaxFp ()).convertToFloat ();
383+ // Check we have a valid NaN propagation combination.
384+ const auto opNanMode = op.getNanMode ();
385+ const auto clampNanMode = clampOp.getNanMode ();
386+ if (opNanMode == " IGNORE" && clampNanMode == " PROPAGATE" )
387+ return failure ();
356388
357- auto minInt = std::max (op.getMinInt (), clampOp.getMinInt ());
358- auto maxInt = std::min (op.getMaxInt (), clampOp.getMaxInt ());
389+ // Check we have intersecting ranges.
390+ const auto opMinInt = op.getMinInt ();
391+ const auto opMaxInt = op.getMaxInt ();
392+ const auto clampOpMinInt = clampOp.getMinInt ();
393+ const auto clampOpMaxInt = clampOp.getMaxInt ();
394+ ClampRange<std::int64_t > opRangeIntRange (opMinInt, opMaxInt);
395+ ClampRange<std::int64_t > clampRangeIntRange (clampOpMinInt, clampOpMaxInt);
396+ if (!opRangeIntRange.intersects (clampRangeIntRange))
397+ return failure ();
359398
360- rewriter.replaceOpWithNewOp <tosa::ClampOp>(
361- op, op.getType (), clampOp.getInput (),
362- rewriter.getI64IntegerAttr (minInt),
363- rewriter.getI64IntegerAttr (maxInt), rewriter.getF32FloatAttr (minFp),
364- rewriter.getF32FloatAttr (maxFp));
365- return success ();
366- }
399+ const auto opMinFloat = op.getMinFp ();
400+ const auto opMaxFloat = op.getMaxFp ();
401+ const auto clampOpMinFloat = clampOp.getMinFp ();
402+ const auto clampOpMaxFloat = clampOp.getMaxFp ();
403+ ClampRange opRangeFloatRange (opMinFloat, opMaxFloat);
404+ ClampRange clampRangeFloatRange (clampOpMinFloat, clampOpMaxFloat);
405+ if (!opRangeFloatRange.intersects (clampRangeFloatRange))
406+ return failure ();
367407
368- return failure ();
408+ // Run the transformation.
409+ const auto minFp = std::max (opMinFloat, clampOpMinFloat).convertToFloat ();
410+ const auto maxFp = std::min (opMaxFloat, clampOpMaxFloat).convertToFloat ();
411+ const auto minInt = std::max (opMinInt, clampOpMinInt);
412+ const auto maxInt = std::min (opMaxInt, clampOpMaxInt);
413+ rewriter.replaceOpWithNewOp <tosa::ClampOp>(
414+ op, op.getType (), clampOp.getInput (),
415+ rewriter.getI64IntegerAttr (minInt), rewriter.getI64IntegerAttr (maxInt),
416+ rewriter.getF32FloatAttr (minFp), rewriter.getF32FloatAttr (maxFp),
417+ rewriter.getStringAttr ((opNanMode != clampNanMode) ? " IGNORE"
418+ : opNanMode));
419+ return success ();
369420 }
370421};
371422
0 commit comments