Skip to content

Commit b074e04

Browse files
author
hyun gyu kim
committed
[TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support
Currently, the FuseReductionEpilogue primitive only supports Bias (addition) and BiasReLU (addition + ReLU) epilogue patterns. However, clipping operations (min(max(x, lower), upper)) are commonly used in deep learning models and would benefit from the same fusion optimization. This commit extends FuseReductionEpilogue to support Clipping patterns by: 1. Adding EpilogueType::Clipping to the enum to distinguish clipping patterns from other epilogue types. 2. Adding clipping_lower_ and clipping_upper_ members to ReductionEpilogueFuser to store clipping bounds extracted from the epilogue pattern. 3. Extending AnalyzeEpiloguePattern to detect clipping patterns: - min(max(temp, lower), upper) - max(min(temp, upper), lower) - All commutative variants of min/max at each level 4. Updating BiasReLU pattern matching to handle max(0, x) form in addition to max(x, 0) for better commutativity support. 5. Modifying CreateFusedReductionBlock to apply clipping to the init value: init = min(max(0, lower), upper) 6. Updating BufferReplacer to apply clipping per-iteration: value = min(max(value, lower), upper) 7. Adding validation in BodyPatternAllowFusion to ensure temp appears exactly once in clipping patterns. 8. Creating comprehensive test coverage with 8 test cases: - Basic fusion test - Numerical correctness verification - Multiple epilogue blocks test - 5 commutative variant tests This implementation follows the same per-iteration semantics as BiasReLU, where clipping is applied at each reduction step rather than post-reduction. This semantic change is documented in the docstring with a warning about potential numerical differences. The test suite verifies that all commutative forms of clipping patterns are correctly recognized and that the fused implementation produces numerically identical results to the per-iteration reference implementation.
1 parent d48cd25 commit b074e04

File tree

1 file changed

+0
-1
lines changed

1 file changed

+0
-1
lines changed

tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,4 +269,3 @@ def test_func(
269269

270270
if __name__ == "__main__":
271271
tvm.testing.main()
272-

0 commit comments

Comments
 (0)