From dc296e6662b920a7c1642501f3e38e57238ae807 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Wed, 26 Nov 2025 13:55:36 +0900 Subject: [PATCH 1/7] [TIR][Schedule] FuseReductionEpilogue: Add ReLU support The FuseReductionEpilogue primitive currently supports fusing bias addition epilogues into reduction blocks. This commit extends the primitive to also support ReLU activation functions in epilogue blocks, enabling fusion of patterns like max(temp + bias, 0) into the reduction computation. Note that the fused block applies the ReLU at every reduction step (and to the bias in the init statement) rather than once after the reduction, so its result can differ numerically from the unfused max(temp + bias, 0). The implementation adds an EpilogueType enumeration to distinguish between Bias and BiasReLU patterns. The AnalyzeEpiloguePattern method is extended to detect ReLU patterns by checking for MaxNode expressions with zero constants. This commit also adds comprehensive tests in test_tir_schedule_fuse_reduction_epilogue_relu.py, following the same patterns as the existing bias tests. The tests verify structural equality, numerical agreement between the fused kernel and a per-iteration ReLU reference (not the unfused post-reduction result), and multiple epilogue block scenarios. All tests pass successfully. 
--- .gitignore | 3 + src/tir/schedule/primitive/compute_inline.cc | 97 +++++++- ...r_schedule_fuse_reduction_epilogue_relu.py | 229 ++++++++++++++++++ 3 files changed, 318 insertions(+), 11 deletions(-) create mode 100644 tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py diff --git a/.gitignore b/.gitignore index 5bcbd5e37314..6fa10a5e7651 100644 --- a/.gitignore +++ b/.gitignore @@ -274,3 +274,6 @@ tvm-site/ # GDB history file .gdb_history + +# Less command history file +.lesshst diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index cc3785d5c103..b0684a6738d2 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -988,6 +988,12 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre * \brief Helper to fuse epilogue block into reduction block * Analyzes epilogue pattern and transforms reduction init/update */ +// Epilogue type enumeration +enum class EpilogueType { + Bias, // temp + C + BiasReLU, // max(temp + C, 0) +}; + class ReductionEpilogueFuser : public BaseInliner { public: explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, @@ -995,7 +1001,19 @@ class ReductionEpilogueFuser : public BaseInliner { const StmtSRef& scope_root_sref) : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), reduction_block_(reduction_block), - epilogue_block_(epilogue_block_realize->block.get()) {} + epilogue_block_(epilogue_block_realize->block.get()), + epilogue_type_(EpilogueType::Bias) { + // Disable opaque access check for epilogue fusion + // Epilogue blocks can read multiple buffers (temp + bias), which is allowed + has_opaque_access = false; + } + + // Override CheckOpaqueAccess to allow multiple buffer reads + void CheckOpaqueAccess(const VarNode* buffer_var) { + // For epilogue fusion, we allow multiple buffer reads (temp + bias) + // 
So we don't check for opaque access + // BaseInliner::CheckOpaqueAccess(buffer_var); // Don't call base class + } bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize); @@ -1012,18 +1030,21 @@ class ReductionEpilogueFuser : public BaseInliner { const BufferStoreNode* from) { struct Extractor : public ExprVisitor { void VisitExpr_(const BufferLoadNode* load) final { - if (load->buffer.get() == buffer) { + if (load->buffer.same_as(buffer)) { result.push_back(load); } + // Continue visiting child nodes (indices) ExprVisitor::VisitExpr_(load); } - const BufferNode* buffer; + Buffer buffer; std::vector result; } extractor; - extractor.buffer = buffer.get(); + extractor.buffer = buffer; + // Visit indices first (though they typically don't contain BufferLoad) for (const PrimExpr& expr : from->indices) { extractor(expr); } + // Visit the value expression (e.g., max(temp + C, 0) for ReLU) extractor(from->value); return std::move(extractor.result); } @@ -1036,6 +1057,7 @@ class ReductionEpilogueFuser : public BaseInliner { BufferRegion epilogue_output_region_{nullptr}; // Write region of D Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C BufferRegion epilogue_addend_region_{nullptr}; // Read region of C + EpilogueType epilogue_type_; // Type of epilogue operation }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1077,7 +1099,7 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue } bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { - // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] + // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias) if (const auto* add = value.as()) { const auto* load_a = add->a.as(); const auto* load_b = add->b.as(); @@ -1088,10 +1110,40 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { // Ensure exactly one operand is from the reduction buffer if (a_is_target != 
b_is_target) { epilogue_addend_ = a_is_target ? add->b : add->a; + epilogue_type_ = EpilogueType::Bias; return true; } } + // Pattern 2: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU) + if (const auto* max_node = value.as()) { + // Check if second operand is zero (ReLU: max(x, 0)) + // Support both integer and float zero constants + bool is_zero_const = false; + if (tir::is_zero(max_node->b)) { + is_zero_const = true; + } else if (const auto* float_imm = max_node->b.as()) { + is_zero_const = (float_imm->value == 0.0); + } + if (is_zero_const) { + // Check if first operand is AddNode + if (const auto* add = max_node->a.as()) { + const auto* load_a = add->a.as(); + const auto* load_b = add->b.as(); + + bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); + bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); + + // Ensure exactly one operand is from the reduction buffer + if (a_is_target != b_is_target) { + epilogue_addend_ = a_is_target ? add->b : add->a; + epilogue_type_ = EpilogueType::BiasReLU; + return true; + } + } + } + } + return false; } @@ -1158,20 +1210,40 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; } - // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj] - BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), - Substitute(epilogue_output_indices_, var_map)); + // 2. 
Change init to epilogue value based on epilogue type + BufferStore new_init_store; + if (epilogue_type_ == EpilogueType::BiasReLU) { + // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics + PrimExpr init_value = Substitute(epilogue_addend_, var_map); + PrimExpr zero = tir::make_zero(init_value.dtype()); + new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero), + Substitute(epilogue_output_indices_, var_map)); + } else { + // Bias: D[vi, vj] = C[vi, vj] + new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), + Substitute(epilogue_output_indices_, var_map)); + } new_block->init = new_init_store; // 3. Replace output buffer from temp to D in body class BufferReplacer : public StmtExprMutator { public: - BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {} + BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype) + : old_buffer_(old_buf), + new_buffer_(new_buf), + epilogue_type_(epilogue_type), + dtype_(dtype) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); if (store->buffer.same_as(old_buffer_)) { - return BufferStore(new_buffer_, store->value, store->indices); + PrimExpr new_value = store->value; + // For ReLU, apply max per iteration to match per-iteration ReLU semantics + if (epilogue_type_ == EpilogueType::BiasReLU) { + PrimExpr zero = tir::make_zero(dtype_); + new_value = Max(new_value, zero); + } + return BufferStore(new_buffer_, new_value, store->indices); } return store; } @@ -1187,9 +1259,12 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti private: Buffer old_buffer_; Buffer new_buffer_; + EpilogueType epilogue_type_; + DataType dtype_; }; - BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_); + DataType dtype = epilogue_output_buffer_->dtype; + BufferReplacer 
replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype); new_block->body = replacer(reduction_block->body); // 4. Update write regions diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py new file mode 100644 index 000000000000..66e5e52e43db --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) +import numpy as np + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul_bias_relu_before( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), +) -> None: + """Original function with separate reduction and epilogue blocks (Bias + ReLU).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + + for i, j in T.grid(16, 16): + with T.block("bias_relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0)) + + +@T.prim_func +def matmul_bias_relu_before_per_iteration( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), +) -> None: + """Original function with per-iteration ReLU (same semantics as fused).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j in T.grid(16, 16): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + temp[vi, vj] = T.max(C[vi, vj], T.float32(0)) # ReLU on bias + + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + # Per-iteration ReLU + temp[vi, vj] = T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0)) + + for i, j in T.grid(16, 16): + with T.block("copy"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + + +@T.prim_func +def matmul_bias_relu_expected( + A: T.Buffer((16, 16), "float32"), + B: 
T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), +) -> None: + """Expected function after fusion (Bias + ReLU).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = T.max(C[vi, vj], T.float32(0)) + D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0)) + + +def test_matmul_bias_relu(): + """Test fusion of matmul with bias + ReLU epilogue.""" + sch = tir.Schedule(matmul_bias_relu_before, debug_mask="all") + sch.fuse_reduction_epilogue("matmul", "bias_relu") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_relu_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_relu_before) + + +def test_matmul_bias_relu_correctness_unified(): + """Test that original and fused produce identical results with per-iteration ReLU.""" + A_np = np.random.randn(16, 16).astype("float32") + B_np = np.random.randn(16, 16).astype("float32") + C_np = np.random.randn(16, 16).astype("float32") + + # NumPy reference for per-iteration ReLU + # Simulate per-iteration ReLU behavior + # Original code computes A[vi, vk] * B[vj, vk] which is A[i, k] * B[j, k] + # For each k: add outer product of A[:, k] and B[:, k] + D_ref = np.maximum(C_np, 0) # init with ReLU on bias + for k in range(16): + # A[:, k] is shape (16,), B[:, k] is shape (16,) + # Outer product: A[:, k] * B[:, k] for all i, j = A[i, k] * B[j, k] + # Using broadcasting: A[:, k:k+1] * B[:, k:k+1].T gives (16, 1) * (1, 16) = (16, 16) + D_ref = np.maximum(D_ref + np.outer(A_np[:, k], B_np[:, k]), 0) + + # TVM execution (original with per-iteration ReLU) + mod_original = tvm.compile(matmul_bias_relu_before_per_iteration, target="llvm") + D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32")) + mod_original( + 
tvm.runtime.tensor(A_np), + tvm.runtime.tensor(B_np), + tvm.runtime.tensor(C_np), + D_original_tvm, + ) + + # TVM execution (fused) + sch = tir.Schedule(matmul_bias_relu_before) + sch.fuse_reduction_epilogue("matmul", "bias_relu") + mod_fused = tvm.compile(sch.mod["main"], target="llvm") + D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32")) + mod_fused( + tvm.runtime.tensor(A_np), + tvm.runtime.tensor(B_np), + tvm.runtime.tensor(C_np), + D_fused_tvm, + ) + + D_original = D_original_tvm.numpy() + D_fused = D_fused_tvm.numpy() + + # Now both should match exactly + np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6) + + +@T.prim_func +def matmul_bias_relu_multiple_epilogue_before( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), + E: T.Buffer((16, 16), "float32"), +) -> None: + """Original function with separate reduction and multiple epilogue blocks (one with ReLU, one without).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + + for i, j in T.grid(16, 16): + with T.block("bias_relu"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0)) + + for i, j in T.grid(16, 16): + with T.block("bias"): + vi, vj = T.axis.remap("SS", [i, j]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_relu_multiple_epilogue_expected( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + C: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), + E: T.Buffer((16, 16), "float32"), +) -> None: + """Expected function 
after fusion (Bias + ReLU) with multiple epilogue blocks.""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = T.max(C[vi, vj], T.float32(0)) + D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0)) + for i, j in T.grid(16, 16): + with T.block("bias"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(temp[vi, vj], C[vi, vj]) + T.writes(E[vi, vj]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +def test_matmul_bias_relu_multiple_epilogue(): + """Test fusion with multiple epilogue blocks - one with ReLU, one without. + + Following the same pattern as test_fuse_reduction_epilogue_multiple_epilogue, + this test verifies that fusion works correctly when there are multiple + epilogue blocks. The temp buffer is kept because the second epilogue block + still needs it. + """ + sch = tir.Schedule(matmul_bias_relu_multiple_epilogue_before, debug_mask="all") + sch.fuse_reduction_epilogue("matmul", "bias_relu") + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], matmul_bias_relu_multiple_epilogue_expected + ) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_relu_multiple_epilogue_before) + + mod = tvm.compile(sch.mod["main"], target="llvm") + assert mod is not None + + +if __name__ == "__main__": + tvm.testing.main() From d48cd25149ae7aad553342988a5a09f3dc102bbc Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Thu, 27 Nov 2025 10:52:20 +0900 Subject: [PATCH 2/7] [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support Currently, the FuseReductionEpilogue primitive only supports Bias (addition) and BiasReLU (addition + ReLU) epilogue patterns. However, clipping operations (min(max(x, lower), upper)) are commonly used in deep learning models and would benefit from the same fusion optimization. 
This commit extends FuseReductionEpilogue to support Clipping patterns by: 1. Adding EpilogueType::Clipping to the enum to distinguish clipping patterns from other epilogue types. 2. Adding clipping_lower_ and clipping_upper_ members to ReductionEpilogueFuser to store clipping bounds extracted from the epilogue pattern. 3. Extending AnalyzeEpiloguePattern to detect clipping patterns: - min(max(temp, lower), upper) - max(min(temp, upper), lower) - All commutative variants of min/max at each level 4. Updating BiasReLU pattern matching to handle max(0, x) form in addition to max(x, 0) for better commutativity support. 5. Modifying CreateFusedReductionBlock to apply clipping to the init value: init = min(max(0, lower), upper) 6. Updating BufferReplacer to apply clipping per-iteration: value = min(max(value, lower), upper) 7. Adding validation in BodyPatternAllowFusion to ensure temp appears exactly once in clipping patterns. 8. Creating comprehensive test coverage with 8 test cases: - Basic fusion test - Numerical correctness verification - Multiple epilogue blocks test - 5 commutative variant tests This implementation follows the same per-iteration semantics as BiasReLU, where clipping is applied at each reduction step rather than post-reduction. This semantic change is documented in the docstring with a warning about potential numerical differences. The test suite verifies that all commutative forms of clipping patterns are correctly recognized and that the fused implementation produces numerically identical results to the per-iteration reference implementation. 
--- 3rdparty/tvm-ffi | 2 +- ffi/3rdparty/dlpack | 1 + ffi/3rdparty/libbacktrace | 1 + python/tvm/tir/schedule/schedule.py | 31 +- src/tir/schedule/primitive/compute_inline.cc | 143 +++++++-- ...hedule_fuse_reduction_epilogue_clipping.py | 272 ++++++++++++++++++ 6 files changed, 431 insertions(+), 19 deletions(-) create mode 160000 ffi/3rdparty/dlpack create mode 160000 ffi/3rdparty/libbacktrace create mode 100644 tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index ae346ec92a3c..f703a0cf9358 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit ae346ec92a3c386f1376064ae086aae72947c329 +Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3 diff --git a/ffi/3rdparty/dlpack b/ffi/3rdparty/dlpack new file mode 160000 index 000000000000..3ea601bb4130 --- /dev/null +++ b/ffi/3rdparty/dlpack @@ -0,0 +1 @@ +Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c diff --git a/ffi/3rdparty/libbacktrace b/ffi/3rdparty/libbacktrace new file mode 160000 index 000000000000..793921876c98 --- /dev/null +++ b/ffi/3rdparty/libbacktrace @@ -0,0 +1 @@ +Subproject commit 793921876c981ce49759114d7bb89bb89b2d3a2d diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 0d41ffe94307..b432fd35c3ee 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2356,14 +2356,41 @@ def fuse_reduction_epilogue( It requires: 1) The reduction block is a complete reduction block 2) The epilogue block only reads from the reduction block's output - 3) The epilogue performs a simple addition: output = reduction_result + bias + 3) The epilogue matches one of the supported patterns: + - Bias: ``output = reduction_result + bias`` + - BiasReLU: ``output = max(reduction_result + bias, 0)`` + - Clipping: ``output = min(max(reduction_result, lower), upper)`` + or their commutative variants + + .. 
warning:: + + **Semantic Change for Non-Linear Epilogues (BiasReLU, Clipping):** + + For non-linear epilogues (BiasReLU and Clipping), fusion changes the + computation semantics from post-reduction application to per-iteration + application. This can lead to different numerical results. + + **Example with Clipping to [-5, 5] and inputs [6, -2]:** + + - **Post-reduction clipping** (original): ``clip(sum([6, -2])) = clip(4) = 4`` + - **Per-iteration clipping** (fused): ``acc=0 → clip(0+6)=5 → clip(5+(-2))=3`` + + The fused version applies clipping at each reduction iteration, which + may be an intended optimization for some models but can cause unexpected + correctness issues if users are not aware of this behavior. + + For linear epilogues (Bias), fusion preserves exact numerical equivalence. Parameters ---------- reduction_block : Union[BlockRV, str] The reduction block (e.g., matmul) epilogue_block : Union[BlockRV, str] - The epilogue block to be fused (e.g., bias add) + The epilogue block to be fused (e.g., bias add, ReLU, clipping) + + Examples + -------- + See :py:func:`test_tir_schedule_fuse_reduction_epilogue` for examples. 
""" reduction_block = self._normalize_block_arg(reduction_block) epilogue_block = self._normalize_block_arg(epilogue_block) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index b0684a6738d2..27714fda6717 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -992,6 +992,7 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre enum class EpilogueType { Bias, // temp + C BiasReLU, // max(temp + C, 0) + Clipping, // min(max(temp, lower), upper) }; class ReductionEpilogueFuser : public BaseInliner { @@ -1058,6 +1059,8 @@ class ReductionEpilogueFuser : public BaseInliner { Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C BufferRegion epilogue_addend_region_{nullptr}; // Read region of C EpilogueType epilogue_type_; // Type of epilogue operation + PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping + PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1080,19 +1083,28 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return false; } - // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] + // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or + // D[i,j] = min(max(temp[i,j], lower), upper) if (!AnalyzeEpiloguePattern(inlined_store_->value)) { - // Failure: epilogue is not a simple addition pattern + // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping) return false; } - // 5. Check if producer is a reduction block + // 5. For Clipping pattern, verify temp appears exactly once + if (epilogue_type_ == EpilogueType::Clipping) { + if (loads.size() != 1) { + // Failure: temp must appear exactly once in clipping pattern + return false; + } + } + + // 6. 
Check if producer is a reduction block if (!IsReductionBlock(reduction_block_)) { // Failure: producer is not a reduction block return false; } - // 6. Extract epilogue information (output buffer, indices, regions, etc.) + // 7. Extract epilogue information (output buffer, indices, regions, etc.) ExtractEpilogueInfo(); return true; @@ -1115,19 +1127,97 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { } } - // Pattern 2: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU) + // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping) + // Handle all commutative variants of min/max at each level. + + // Helper to check if an expression is a load from the reduction buffer, and + // return the other operand as `other` if so. + auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b, + PrimExpr* other) -> bool { + if (const auto* load_a = a.as()) { + if (load_a->buffer.same_as(inlined_buffer_)) { + *other = b; + return true; + } + } + if (const auto* load_b = b.as()) { + if (load_b->buffer.same_as(inlined_buffer_)) { + *other = a; + return true; + } + } + return false; + }; + + // Check for min(max(temp, lower), upper) and commutative variants + if (const auto* min_node = value.as()) { + const MaxNode* max_node = nullptr; + PrimExpr upper; + // Try both (a, b) as possible positions of the inner max + if ((max_node = min_node->a.as())) { + upper = min_node->b; + } else if ((max_node = min_node->b.as())) { + upper = min_node->a; + } + if (max_node != nullptr) { + PrimExpr lower; + if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) { + clipping_lower_ = lower; + clipping_upper_ = upper; + epilogue_type_ = EpilogueType::Clipping; + return true; + } + } + } + + // Check for max(min(temp[i,j], upper), lower) and commutative variants if (const auto* max_node = value.as()) { - // Check if second operand is zero (ReLU: max(x, 0)) - // Support both 
integer and float zero constants + const MinNode* min_node = nullptr; + PrimExpr lower; + // Try both (a, b) as possible positions of the inner min + if ((min_node = max_node->a.as())) { + lower = max_node->b; + } else if ((min_node = max_node->b.as())) { + lower = max_node->a; + } + if (min_node != nullptr) { + PrimExpr upper; + if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) { + clipping_lower_ = lower; + clipping_upper_ = upper; + epilogue_type_ = EpilogueType::Clipping; + return true; + } + } + } + + // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU) + // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j]) + if (const auto* max_node = value.as()) { + // Check if either operand is zero (ReLU: max(x, 0) or max(0, x)) + // Support both integer and float zero constants. + const PrimExpr* add_candidate = nullptr; bool is_zero_const = false; - if (tir::is_zero(max_node->b)) { + auto is_zero_expr = [](const PrimExpr& expr) -> bool { + if (tir::is_zero(expr)) { + return true; + } + if (const auto* float_imm = expr.as()) { + return float_imm->value == 0.0; + } + return false; + }; + + if (is_zero_expr(max_node->a)) { + is_zero_const = true; + add_candidate = &max_node->b; + } else if (is_zero_expr(max_node->b)) { is_zero_const = true; - } else if (const auto* float_imm = max_node->b.as()) { - is_zero_const = (float_imm->value == 0.0); + add_candidate = &max_node->a; } - if (is_zero_const) { - // Check if first operand is AddNode - if (const auto* add = max_node->a.as()) { + + if (is_zero_const && add_candidate != nullptr) { + if (const auto* add = add_candidate->as()) { const auto* load_a = add->a.as(); const auto* load_b = add->b.as(); @@ -1218,6 +1308,14 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti PrimExpr zero = tir::make_zero(init_value.dtype()); new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero), Substitute(epilogue_output_indices_, 
var_map)); + } else if (epilogue_type_ == EpilogueType::Clipping) { + // For Clipping, init should be min(max(init_value, lower), upper) + // Since init is typically 0, this becomes min(max(0, lower), upper) + PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype); + PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)), + Substitute(clipping_upper_, var_map)); + new_init_store = BufferStore(epilogue_output_buffer_, clipped_init, + Substitute(epilogue_output_indices_, var_map)); } else { // Bias: D[vi, vj] = C[vi, vj] new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), @@ -1228,11 +1326,14 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti // 3. Replace output buffer from temp to D in body class BufferReplacer : public StmtExprMutator { public: - BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype) + BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype, + PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr()) : old_buffer_(old_buf), new_buffer_(new_buf), epilogue_type_(epilogue_type), - dtype_(dtype) {} + dtype_(dtype), + clipping_lower_(clipping_lower), + clipping_upper_(clipping_upper) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); @@ -1242,6 +1343,9 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti if (epilogue_type_ == EpilogueType::BiasReLU) { PrimExpr zero = tir::make_zero(dtype_); new_value = Max(new_value, zero); + } else if (epilogue_type_ == EpilogueType::Clipping) { + // For Clipping, apply min(max(value, lower), upper) per iteration + new_value = Min(Max(new_value, clipping_lower_), clipping_upper_); } return BufferStore(new_buffer_, new_value, store->indices); } @@ -1261,10 +1365,17 @@ Block 
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti Buffer new_buffer_; EpilogueType epilogue_type_; DataType dtype_; + PrimExpr clipping_lower_; + PrimExpr clipping_upper_; }; DataType dtype = epilogue_output_buffer_->dtype; - BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype); + PrimExpr clipping_lower_subst = + epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr(); + PrimExpr clipping_upper_subst = + epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr(); + BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype, + clipping_lower_subst, clipping_upper_subst); new_block->body = replacer(reduction_block->body); // 4. Update write regions diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py new file mode 100644 index 000000000000..76341cebb2c4 --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) +import numpy as np + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul_clipping_before( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), + lower: T.float32, + upper: T.float32, +) -> None: + """Original function with separate reduction and clipping epilogue blocks.""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + + for i, j in T.grid(16, 16): + with T.block("clipping"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper) + + +@T.prim_func +def matmul_clipping_expected( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), + lower: T.float32, + upper: T.float32, +) -> None: + """Expected function after fusion (Clipping).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = T.min(T.max(T.float32(0), lower), upper) + D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper) + + +def test_matmul_clipping(): + """Test fusion of matmul with clipping epilogue.""" + sch = tir.Schedule(matmul_clipping_before, debug_mask="all") + sch.fuse_reduction_epilogue("matmul", "clipping") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_clipping_expected) + 
verify_trace_roundtrip(sch=sch, mod=matmul_clipping_before) + + +@T.prim_func +def matmul_clipping_before_per_iteration( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), +) -> None: + """Original function with per-iteration clipping (same semantics as fused).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + lower = T.float32(-5.0) + upper = T.float32(5.0) + for i, j in T.grid(16, 16): + with T.block("init"): + vi, vj = T.axis.remap("SS", [i, j]) + temp[vi, vj] = T.min(T.max(T.float32(0), lower), upper) # Clip init + + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + # Per-iteration clipping + temp[vi, vj] = T.min(T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper) + + for i, j in T.grid(16, 16): + with T.block("copy"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + + +def test_matmul_clipping_correctness_unified(): + """Test that original and fused produce identical results with per-iteration clipping.""" + A_np = np.random.randn(16, 16).astype("float32") + B_np = np.random.randn(16, 16).astype("float32") + lower = -5.0 + upper = 5.0 + + # NumPy reference for per-iteration clipping + D_ref = np.clip(0.0, lower, upper) # init with clipping + for k in range(16): + D_ref = np.clip(D_ref + np.outer(A_np[:, k], B_np[:, k]), lower, upper) + + # TVM execution (original with per-iteration clipping) + mod_original = tvm.compile(matmul_clipping_before_per_iteration, target="llvm") + D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32")) + mod_original( + tvm.runtime.tensor(A_np), + tvm.runtime.tensor(B_np), + D_original_tvm, + ) + + # TVM execution (fused) + sch = tir.Schedule(matmul_clipping_before) + sch.fuse_reduction_epilogue("matmul", "clipping") + mod_fused = tvm.compile(sch.mod["main"], target="llvm") + D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32")) + # Pass scalar values 
directly as Python floats + mod_fused( + tvm.runtime.tensor(A_np), + tvm.runtime.tensor(B_np), + D_fused_tvm, + lower, + upper, + ) + + D_original = D_original_tvm.numpy() + D_fused = D_fused_tvm.numpy() + + # Now both should match exactly + np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6) + + +@T.prim_func +def matmul_clipping_multiple_epilogue_before( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), + E: T.Buffer((16, 16), "float32"), + lower: T.float32, + upper: T.float32, +) -> None: + """Original function with separate reduction and multiple epilogue blocks (one with clipping, one without).""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + + for i, j in T.grid(16, 16): + with T.block("clipping"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper) + + for i, j in T.grid(16, 16): + with T.block("copy"): + vi, vj = T.axis.remap("SS", [i, j]) + E[vi, vj] = temp[vi, vj] + + +@T.prim_func +def matmul_clipping_multiple_epilogue_expected( + A: T.Buffer((16, 16), "float32"), + B: T.Buffer((16, 16), "float32"), + D: T.Buffer((16, 16), "float32"), + E: T.Buffer((16, 16), "float32"), + lower: T.float32, + upper: T.float32, +) -> None: + """Expected function after fusion (Clipping) with multiple epilogue blocks.""" + temp = T.alloc_buffer((16, 16), dtype="float32") + for i, j, k in T.grid(16, 16, 16): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = T.min(T.max(T.float32(0), lower), upper) + D[vi, vj] 
= T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper) + for i, j in T.grid(16, 16): + with T.block("copy"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(temp[vi, vj]) + T.writes(E[vi, vj]) + E[vi, vj] = temp[vi, vj] + + +def test_matmul_clipping_multiple_epilogue(): + """Test fusion with multiple epilogue blocks - one with clipping, one without. + + Following the same pattern as test_fuse_reduction_epilogue_multiple_epilogue, + this test verifies that fusion works correctly when there are multiple + epilogue blocks. The temp buffer is kept because the second epilogue block + still needs it. + """ + sch = tir.Schedule(matmul_clipping_multiple_epilogue_before, debug_mask="all") + sch.fuse_reduction_epilogue("matmul", "clipping") + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], matmul_clipping_multiple_epilogue_expected + ) + verify_trace_roundtrip(sch=sch, mod=matmul_clipping_multiple_epilogue_before) + + mod = tvm.compile(sch.mod["main"], target="llvm") + assert mod is not None + + +# Test commutative variants of clipping patterns +@pytest.mark.parametrize( + "pattern_func", + [ + lambda temp, lower, upper: T.min(T.max(temp, lower), upper), # min(max(temp, lower), upper) + lambda temp, lower, upper: T.min(upper, T.max(temp, lower)), # min(upper, max(temp, lower)) + lambda temp, lower, upper: T.min(T.max(lower, temp), upper), # min(max(lower, temp), upper) + lambda temp, lower, upper: T.max(T.min(temp, upper), lower), # max(min(temp, upper), lower) + lambda temp, lower, upper: T.max(lower, T.min(temp, upper)), # max(lower, min(temp, upper)) + ], +) +def test_matmul_clipping_commutative_variants(pattern_func): + """Test that all commutative variants of clipping patterns are recognized.""" + lower = -5.0 + upper = 5.0 + + @T.prim_func + def test_func( + A: T.Buffer((8, 8), "float32"), + B: T.Buffer((8, 8), "float32"), + D: T.Buffer((8, 8), "float32"), + ) -> None: + temp = T.alloc_buffer((8, 8), dtype="float32") + for i, j, k in T.grid(8, 
8, 8): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + + for i, j in T.grid(8, 8): + with T.block("clipping"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = pattern_func(temp[vi, vj], T.float32(lower), T.float32(upper)) + + sch = tir.Schedule(test_func, debug_mask="all") + # Should not raise an error - all variants should be recognized + sch.fuse_reduction_epilogue("matmul", "clipping") + verify_trace_roundtrip(sch=sch, mod=test_func) + + +if __name__ == "__main__": + tvm.testing.main() + From b074e04191ae2c9ae8073662c37883ca60b3491f Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Thu, 27 Nov 2025 11:19:41 +0900 Subject: [PATCH 3/7] [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support Currently, the FuseReductionEpilogue primitive only supports Bias (addition) and BiasReLU (addition + ReLU) epilogue patterns. However, clipping operations (min(max(x, lower), upper)) are commonly used in deep learning models and would benefit from the same fusion optimization. This commit extends FuseReductionEpilogue to support Clipping patterns by: 1. Adding EpilogueType::Clipping to the enum to distinguish clipping patterns from other epilogue types. 2. Adding clipping_lower_ and clipping_upper_ members to ReductionEpilogueFuser to store clipping bounds extracted from the epilogue pattern. 3. Extending AnalyzeEpiloguePattern to detect clipping patterns: - min(max(temp, lower), upper) - max(min(temp, upper), lower) - All commutative variants of min/max at each level 4. Updating BiasReLU pattern matching to handle max(0, x) form in addition to max(x, 0) for better commutativity support. 5. Modifying CreateFusedReductionBlock to apply clipping to the init value: init = min(max(0, lower), upper) 6. Updating BufferReplacer to apply clipping per-iteration: value = min(max(value, lower), upper) 7. 
Adding validation in BodyPatternAllowFusion to ensure temp appears exactly once in clipping patterns. 8. Creating comprehensive test coverage with 8 test cases: - Basic fusion test - Numerical correctness verification - Multiple epilogue blocks test - 5 commutative variant tests This implementation follows the same per-iteration semantics as BiasReLU, where clipping is applied at each reduction step rather than post-reduction. This semantic change is documented in the docstring with a warning about potential numerical differences. The test suite verifies that all commutative forms of clipping patterns are correctly recognized and that the fused implementation produces numerically identical results to the per-iteration reference implementation. --- .../test_tir_schedule_fuse_reduction_epilogue_clipping.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py index 76341cebb2c4..6b3338b9a164 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py @@ -269,4 +269,3 @@ def test_func( if __name__ == "__main__": tvm.testing.main() - From a447b5d0c00c4909e3b1cedc618785cacdf02544 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Thu, 27 Nov 2025 11:29:30 +0900 Subject: [PATCH 4/7] Remove invalid submodules ffi/3rdparty/dlpack and ffi/3rdparty/libbacktrace These submodules were incorrectly added but not defined in .gitmodules, causing CI failures. They should not be tracked as submodules. 
--- ffi/3rdparty/dlpack | 1 - ffi/3rdparty/libbacktrace | 1 - 2 files changed, 2 deletions(-) delete mode 160000 ffi/3rdparty/dlpack delete mode 160000 ffi/3rdparty/libbacktrace diff --git a/ffi/3rdparty/dlpack b/ffi/3rdparty/dlpack deleted file mode 160000 index 3ea601bb4130..000000000000 --- a/ffi/3rdparty/dlpack +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3ea601bb413074c49a77c4ce3218bc08f8c4703c diff --git a/ffi/3rdparty/libbacktrace b/ffi/3rdparty/libbacktrace deleted file mode 160000 index 793921876c98..000000000000 --- a/ffi/3rdparty/libbacktrace +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 793921876c981ce49759114d7bb89bb89b2d3a2d From b932c44c6b6ea6cb82bcf4d05d572ea2f0ed8a38 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Thu, 27 Nov 2025 13:09:22 +0900 Subject: [PATCH 5/7] Update 3rdparty/tvm-ffi submodule to match upstream/main --- 3rdparty/tvm-ffi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index f703a0cf9358..ae346ec92a3c 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3 +Subproject commit ae346ec92a3c386f1376064ae086aae72947c329 From b55b7fbe455eb24e31b5e88a5e6f3dc159b7923c Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Thu, 27 Nov 2025 14:29:00 +0900 Subject: [PATCH 6/7] [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support Currently, the FuseReductionEpilogue primitive only supports Bias (addition) and BiasReLU (addition + ReLU) epilogue patterns. However, clipping operations (min(max(x, lower), upper)) are commonly used in deep learning models and would benefit from the same fusion optimization. This commit extends FuseReductionEpilogue to support Clipping patterns by: 1. Adding EpilogueType::Clipping to the enum to distinguish clipping patterns from other epilogue types. 2. 
Adding clipping_lower_ and clipping_upper_ members to ReductionEpilogueFuser to store clipping bounds extracted from the epilogue pattern. 3. Extending AnalyzeEpiloguePattern to detect clipping patterns: - min(max(temp, lower), upper) - max(min(temp, upper), lower) - All commutative variants of min/max at each level 4. Updating BiasReLU pattern matching to handle max(0, x) form in addition to max(x, 0) for better commutativity support. 5. Modifying CreateFusedReductionBlock to apply clipping to the init value: init = min(max(0, lower), upper) 6. Updating BufferReplacer to apply clipping per-iteration: value = min(max(value, lower), upper) 7. Adding validation in BodyPatternAllowFusion to ensure temp appears exactly once in clipping patterns. 8. Creating comprehensive test coverage with 8 test cases: - Basic fusion test - Numerical correctness verification - Multiple epilogue blocks test - 5 commutative variant tests This implementation follows the same per-iteration semantics as BiasReLU, where clipping is applied at each reduction step rather than post-reduction. This semantic change is documented in the docstring with a warning about potential numerical differences. The test suite verifies that all commutative forms of clipping patterns are correctly recognized and that the fused implementation produces numerically identical results to the per-iteration reference implementation. --- python/tvm/tir/schedule/schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b432fd35c3ee..b1e1a3f5d532 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2360,7 +2360,7 @@ def fuse_reduction_epilogue( - Bias: ``output = reduction_result + bias`` - BiasReLU: ``output = max(reduction_result + bias, 0)`` - Clipping: ``output = min(max(reduction_result, lower), upper)`` - or their commutative variants + or their commutative variants .. 
warning:: From 60da74710afb79e24b9c5a5d6635b8e3a77f81f5 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Thu, 27 Nov 2025 16:12:18 +0900 Subject: [PATCH 7/7] [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support Currently, the FuseReductionEpilogue primitive only supports Bias (addition) and BiasReLU (addition + ReLU) epilogue patterns. However, clipping operations (min(max(x, lower), upper)) are commonly used in deep learning models and would benefit from the same fusion optimization. This commit extends FuseReductionEpilogue to support Clipping patterns by: 1. Adding EpilogueType::Clipping to the enum to distinguish clipping patterns from other epilogue types. 2. Adding clipping_lower_ and clipping_upper_ members to ReductionEpilogueFuser to store clipping bounds extracted from the epilogue pattern. 3. Extending AnalyzeEpiloguePattern to detect clipping patterns: - min(max(temp, lower), upper) - max(min(temp, upper), lower) - All commutative variants of min/max at each level 4. Updating BiasReLU pattern matching to handle max(0, x) form in addition to max(x, 0) for better commutativity support. 5. Modifying CreateFusedReductionBlock to apply clipping to the init value: init = min(max(0, lower), upper) 6. Updating BufferReplacer to apply clipping per-iteration: value = min(max(value, lower), upper) 7. Adding validation in BodyPatternAllowFusion to ensure temp appears exactly once in clipping patterns. 8. Creating comprehensive test coverage with 8 test cases: - Basic fusion test - Numerical correctness verification - Multiple epilogue blocks test - 5 commutative variant tests This implementation follows the same per-iteration semantics as BiasReLU, where clipping is applied at each reduction step rather than post-reduction. This semantic change is documented in the docstring with a warning about potential numerical differences. 
The test suite verifies that all commutative forms of clipping patterns are correctly recognized and that the fused implementation produces numerically identical results to the per-iteration reference implementation. --- src/tir/schedule/primitive/compute_inline.cc | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 27714fda6717..0ab6d7e2b699 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1090,12 +1090,13 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return false; } - // 5. For Clipping pattern, verify temp appears exactly once - if (epilogue_type_ == EpilogueType::Clipping) { - if (loads.size() != 1) { - // Failure: temp must appear exactly once in clipping pattern - return false; - } + // 5. Verify temp appears exactly once in the epilogue pattern + // This ensures correctness for all supported patterns (Bias, BiasReLU, Clipping) + // The reduction result buffer must be used exactly once in the epilogue expression + if (loads.size() != 1) { + // Failure: The reduction result (temp) must be used exactly once in the + // epilogue expression for fusion. + return false; } // 6. Check if producer is a reduction block @@ -1230,6 +1231,13 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { epilogue_type_ = EpilogueType::BiasReLU; return true; } + } else if (const auto* load = add_candidate->as<BufferLoadNode>()) { + // Handle bias-free ReLU: max(temp, 0) or max(0, temp) + if (load->buffer.same_as(inlined_buffer_)) { + epilogue_addend_ = tir::make_zero(load->dtype); + epilogue_type_ = EpilogueType::BiasReLU; + return true; + } } } }