diff --git a/.gitignore b/.gitignore
index 5bcbd5e37314..6fa10a5e7651 100644
--- a/.gitignore
+++ b/.gitignore
@@ -274,3 +274,6 @@ tvm-site/
 
 # GDB history file
 .gdb_history
+
+# Less command history file
+.lesshst
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 0d41ffe94307..b1e1a3f5d532 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2356,14 +2356,41 @@ def fuse_reduction_epilogue(
         It requires:
         1) The reduction block is a complete reduction block
         2) The epilogue block only reads from the reduction block's output
-        3) The epilogue performs a simple addition: output = reduction_result + bias
+        3) The epilogue matches one of the supported patterns:
+
+           - Bias: ``output = reduction_result + bias``
+           - BiasReLU: ``output = max(reduction_result + bias, 0)``
+           - Clipping: ``output = min(max(reduction_result, lower), upper)``
+
+           or their commutative variants
+
+        .. warning::
+
+            **Semantic Change for Non-Linear Epilogues (BiasReLU, Clipping):**
+
+            For non-linear epilogues (BiasReLU and Clipping), fusion changes the
+            computation semantics from post-reduction application to per-iteration
+            application. This can lead to different numerical results.
+
+            **Example with Clipping to [-5, 5] and inputs [6, -2]:**
+
+            - **Post-reduction clipping** (original): ``clip(sum([6, -2])) = clip(4) = 4``
+            - **Per-iteration clipping** (fused): ``acc=0 → clip(0+6)=5 → clip(5+(-2))=3``
+
+            The fused version applies clipping at each reduction iteration, which
+            may be an intended optimization for some models but can cause unexpected
+            correctness issues if users are not aware of this behavior.
+
+            For linear epilogues (Bias), fusion preserves exact numerical equivalence.
 
         Parameters
         ----------
         reduction_block : Union[BlockRV, str]
             The reduction block (e.g., matmul)
         epilogue_block : Union[BlockRV, str]
-            The epilogue block to be fused (e.g., bias add)
+            The epilogue block to be fused (e.g., bias add, ReLU, clipping)
+
+        Examples
+        --------
+        See :py:func:`test_tir_schedule_fuse_reduction_epilogue` for examples.
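+
+        A minimal usage sketch (the block names ``"matmul"`` and ``"bias_add"``
+        below are assumptions for illustration, not taken from a specific test):
+
+        .. code-block:: python
+
+            sch = tir.Schedule(mod)
+            sch.fuse_reduction_epilogue("matmul", "bias_add")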
""" reduction_block = self._normalize_block_arg(reduction_block) epilogue_block = self._normalize_block_arg(epilogue_block) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index cc3785d5c103..0ab6d7e2b699 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -988,6 +988,13 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre * \brief Helper to fuse epilogue block into reduction block * Analyzes epilogue pattern and transforms reduction init/update */ +// Epilogue type enumeration +enum class EpilogueType { + Bias, // temp + C + BiasReLU, // max(temp + C, 0) + Clipping, // min(max(temp, lower), upper) +}; + class ReductionEpilogueFuser : public BaseInliner { public: explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, @@ -995,7 +1002,19 @@ class ReductionEpilogueFuser : public BaseInliner { const StmtSRef& scope_root_sref) : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), reduction_block_(reduction_block), - epilogue_block_(epilogue_block_realize->block.get()) {} + epilogue_block_(epilogue_block_realize->block.get()), + epilogue_type_(EpilogueType::Bias) { + // Disable opaque access check for epilogue fusion + // Epilogue blocks can read multiple buffers (temp + bias), which is allowed + has_opaque_access = false; + } + + // Override CheckOpaqueAccess to allow multiple buffer reads + void CheckOpaqueAccess(const VarNode* buffer_var) { + // For epilogue fusion, we allow multiple buffer reads (temp + bias) + // So we don't check for opaque access + // BaseInliner::CheckOpaqueAccess(buffer_var); // Don't call base class + } bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize); @@ -1012,18 +1031,21 @@ class ReductionEpilogueFuser : public BaseInliner { const BufferStoreNode* from) { struct Extractor : public ExprVisitor { void VisitExpr_(const BufferLoadNode* load) final { - if (load->buffer.get() == buffer) { + if (load->buffer.same_as(buffer)) { result.push_back(load); } + // Continue visiting child nodes (indices) ExprVisitor::VisitExpr_(load); } - const BufferNode* buffer; + Buffer buffer; std::vector result; } extractor; - extractor.buffer = buffer.get(); + extractor.buffer = buffer; + // Visit indices first (though they typically don't contain BufferLoad) for (const PrimExpr& expr : from->indices) { extractor(expr); } + // Visit the value expression (e.g., max(temp + C, 0) for ReLU) extractor(from->value); return std::move(extractor.result); } @@ -1036,6 +1058,9 @@ class ReductionEpilogueFuser : public BaseInliner { BufferRegion epilogue_output_region_{nullptr}; // Write region of D Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C BufferRegion epilogue_addend_region_{nullptr}; // Read region of C + EpilogueType epilogue_type_; // Type of epilogue operation + PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping + PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1058,26 +1083,36 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return false; } - // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] + // 4. 
+  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
+  //    D[i,j] = min(max(temp[i,j], lower), upper)
   if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
-    // Failure: epilogue is not a simple addition pattern
+    // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping)
+    return false;
+  }
+
+  // 5. Verify temp appears exactly once in the epilogue pattern
+  // This ensures correctness for all supported patterns (Bias, BiasReLU, Clipping)
+  // The reduction result buffer must be used exactly once in the epilogue expression
+  if (loads.size() != 1) {
+    // Failure: The reduction result (temp) must be used exactly once in the
+    // epilogue expression for fusion.
     return false;
   }
 
-  // 5. Check if producer is a reduction block
+  // 6. Check if producer is a reduction block
   if (!IsReductionBlock(reduction_block_)) {
     // Failure: producer is not a reduction block
     return false;
   }
 
-  // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+  // 7. Extract epilogue information (output buffer, indices, regions, etc.)
   ExtractEpilogueInfo();
   return true;
 }
 
 bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
-  // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+  // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
   if (const auto* add = value.as<AddNode>()) {
     const auto* load_a = add->a.as<BufferLoadNode>();
     const auto* load_b = add->b.as<BufferLoadNode>();
@@ -1088,10 +1123,125 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     // Ensure exactly one operand is from the reduction buffer
     if (a_is_target != b_is_target) {
       epilogue_addend_ = a_is_target ? add->b : add->a;
+      epilogue_type_ = EpilogueType::Bias;
       return true;
     }
   }
 
+  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping)
+  // Handle all commutative variants of min/max at each level.
+
+  // Helper to check if an expression is a load from the reduction buffer, and
+  // return the other operand as `other` if so.
+  auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b,
+                                               PrimExpr* other) -> bool {
+    if (const auto* load_a = a.as<BufferLoadNode>()) {
+      if (load_a->buffer.same_as(inlined_buffer_)) {
+        *other = b;
+        return true;
+      }
+    }
+    if (const auto* load_b = b.as<BufferLoadNode>()) {
+      if (load_b->buffer.same_as(inlined_buffer_)) {
+        *other = a;
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // Check for min(max(temp, lower), upper) and commutative variants
+  if (const auto* min_node = value.as<MinNode>()) {
+    const MaxNode* max_node = nullptr;
+    PrimExpr upper;
+    // Try both (a, b) as possible positions of the inner max
+    if ((max_node = min_node->a.as<MaxNode>())) {
+      upper = min_node->b;
+    } else if ((max_node = min_node->b.as<MaxNode>())) {
+      upper = min_node->a;
+    }
+    if (max_node != nullptr) {
+      PrimExpr lower;
+      if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Check for max(min(temp[i,j], upper), lower) and commutative variants
+  if (const auto* max_node = value.as<MaxNode>()) {
+    const MinNode* min_node = nullptr;
+    PrimExpr lower;
+    // Try both (a, b) as possible positions of the inner min
+    if ((min_node = max_node->a.as<MinNode>())) {
+      lower = max_node->b;
+    } else if ((min_node = max_node->b.as<MinNode>())) {
+      lower = max_node->a;
+    }
+    if (min_node != nullptr) {
+      PrimExpr upper;
+      if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
+  // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
+  if (const auto* max_node = value.as<MaxNode>()) {
+    // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
+    // Support both integer and float zero constants.
+    const PrimExpr* add_candidate = nullptr;
+    bool is_zero_const = false;
+    auto is_zero_expr = [](const PrimExpr& expr) -> bool {
+      if (tir::is_zero(expr)) {
+        return true;
+      }
+      if (const auto* float_imm = expr.as<FloatImmNode>()) {
+        return float_imm->value == 0.0;
+      }
+      return false;
+    };
+
+    if (is_zero_expr(max_node->a)) {
+      is_zero_const = true;
+      add_candidate = &max_node->b;
+    } else if (is_zero_expr(max_node->b)) {
+      is_zero_const = true;
+      add_candidate = &max_node->a;
+    }
+
+    if (is_zero_const && add_candidate != nullptr) {
+      if (const auto* add = add_candidate->as<AddNode>()) {
+        const auto* load_a = add->a.as<BufferLoadNode>();
+        const auto* load_b = add->b.as<BufferLoadNode>();
+
+        bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+        bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+        // Ensure exactly one operand is from the reduction buffer
+        if (a_is_target != b_is_target) {
+          epilogue_addend_ = a_is_target ? add->b : add->a;
+          epilogue_type_ = EpilogueType::BiasReLU;
+          return true;
+        }
+      } else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
+        // Handle bias-free ReLU: max(temp, 0) or max(0, temp)
+        if (load->buffer.same_as(inlined_buffer_)) {
+          epilogue_addend_ = tir::make_zero(load->dtype);
+          epilogue_type_ = EpilogueType::BiasReLU;
+          return true;
+        }
+      }
+    }
+  }
+
   return false;
 }
 
@@ -1158,20 +1308,54 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
     var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
   }
 
-  // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
-  BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
-                             Substitute(epilogue_output_indices_, var_map));
+  // 2. Change init to epilogue value based on epilogue type
+  BufferStore new_init_store;
+  if (epilogue_type_ == EpilogueType::BiasReLU) {
+    // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics
+    PrimExpr init_value = Substitute(epilogue_addend_, var_map);
+    PrimExpr zero = tir::make_zero(init_value.dtype());
+    new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero),
+                                 Substitute(epilogue_output_indices_, var_map));
+  } else if (epilogue_type_ == EpilogueType::Clipping) {
+    // For Clipping, init should be min(max(init_value, lower), upper)
+    // Since init is typically 0, this becomes min(max(0, lower), upper)
+    PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
+    PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)),
+                                Substitute(clipping_upper_, var_map));
+    new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
+                                 Substitute(epilogue_output_indices_, var_map));
+  } else {
+    // Bias: D[vi, vj] = C[vi, vj]
+    new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
+                                 Substitute(epilogue_output_indices_, var_map));
+  }
   new_block->init = new_init_store;
 
   // 3. Replace output buffer from temp to D in body
   class BufferReplacer : public StmtExprMutator {
    public:
-    BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {}
+    BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype,
+                   PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr())
+        : old_buffer_(old_buf),
+          new_buffer_(new_buf),
+          epilogue_type_(epilogue_type),
+          dtype_(dtype),
+          clipping_lower_(clipping_lower),
+          clipping_upper_(clipping_upper) {}
 
     Stmt VisitStmt_(const BufferStoreNode* op) final {
       BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
       if (store->buffer.same_as(old_buffer_)) {
-        return BufferStore(new_buffer_, store->value, store->indices);
+        PrimExpr new_value = store->value;
+        // For ReLU, apply max per iteration to match per-iteration ReLU semantics
+        if (epilogue_type_ == EpilogueType::BiasReLU) {
+          PrimExpr zero = tir::make_zero(dtype_);
+          new_value = Max(new_value, zero);
+        } else if (epilogue_type_ == EpilogueType::Clipping) {
+          // For Clipping, apply min(max(value, lower), upper) per iteration
+          new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
+        }
+        return BufferStore(new_buffer_, new_value, store->indices);
       }
       return store;
     }
@@ -1187,9 +1371,19 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
    private:
     Buffer old_buffer_;
     Buffer new_buffer_;
+    EpilogueType epilogue_type_;
+    DataType dtype_;
+    PrimExpr clipping_lower_;
+    PrimExpr clipping_upper_;
   };
 
-  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
+  DataType dtype = epilogue_output_buffer_->dtype;
+  PrimExpr clipping_lower_subst =
+      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr();
+  PrimExpr clipping_upper_subst =
+      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr();
+  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype,
+                          clipping_lower_subst, clipping_upper_subst);
   new_block->body = replacer(reduction_block->body);
 
   // 4. Update write regions
diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
new file mode 100644
index 000000000000..6b3338b9a164
--- /dev/null
+++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@T.prim_func
+def matmul_clipping_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Original function with separate reduction and clipping epilogue blocks."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    for i, j in T.grid(16, 16):
+        with T.block("clipping"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper)
+
+
+@T.prim_func
+def matmul_clipping_expected(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Expected function after fusion (Clipping)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = T.min(T.max(T.float32(0), lower), upper)
+            D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper)
+
+
+def test_matmul_clipping():
+    """Test fusion of matmul with clipping epilogue."""
+    sch = tir.Schedule(matmul_clipping_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_clipping_expected)
+    verify_trace_roundtrip(sch=sch, mod=matmul_clipping_before)
+
+
+@T.prim_func
+def matmul_clipping_before_per_iteration(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Original function with per-iteration clipping (same semantics as fused)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    lower = T.float32(-5.0)
+    upper = T.float32(5.0)
+    for i, j in T.grid(16, 16):
+        with T.block("init"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            temp[vi, vj] = T.min(T.max(T.float32(0), lower), upper)  # Clip init
+
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            # Per-iteration clipping
+            temp[vi, vj] = T.min(T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper)
+
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = temp[vi, vj]
+
+
+def test_matmul_clipping_correctness_unified():
+    """Test that original and fused produce identical results with per-iteration clipping."""
+    A_np = np.random.randn(16, 16).astype("float32")
+    B_np = np.random.randn(16, 16).astype("float32")
+    lower = -5.0
+    upper = 5.0
+
+    # NumPy reference for per-iteration clipping
+    D_ref = np.clip(0.0, lower, upper)  # init with clipping
+    for k in range(16):
+        D_ref = np.clip(D_ref + np.outer(A_np[:, k], B_np[:, k]), lower, upper)
+
+    # TVM execution (original with per-iteration clipping)
+    mod_original = tvm.compile(matmul_clipping_before_per_iteration, target="llvm")
+    D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+    mod_original(
+        tvm.runtime.tensor(A_np),
+        tvm.runtime.tensor(B_np),
+        D_original_tvm,
+    )
+
+    # TVM execution (fused)
+    sch = tir.Schedule(matmul_clipping_before)
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    mod_fused = tvm.compile(sch.mod["main"], target="llvm")
+    D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+    # Pass scalar values directly as Python floats
+    mod_fused(
+        tvm.runtime.tensor(A_np),
+        tvm.runtime.tensor(B_np),
+        D_fused_tvm,
+        lower,
+        upper,
+    )
+
+    D_original = D_original_tvm.numpy()
+    D_fused = D_fused_tvm.numpy()
+
+    # Now both should match exactly
+    np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6)
+    np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6)
+    np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6)
+
+
+@T.prim_func
+def matmul_clipping_multiple_epilogue_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    E: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Original function with separate reduction and multiple epilogue blocks (one with clipping, one without)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    for i, j in T.grid(16, 16):
+        with T.block("clipping"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper)
+
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            E[vi, vj] = temp[vi, vj]
+
+
+@T.prim_func
+def matmul_clipping_multiple_epilogue_expected(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    E: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Expected function after fusion (Clipping) with multiple epilogue blocks."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = T.min(T.max(T.float32(0), lower), upper)
+            D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), upper)
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(temp[vi, vj])
+            T.writes(E[vi, vj])
+            E[vi, vj] = temp[vi, vj]
+
+
+def test_matmul_clipping_multiple_epilogue():
+    """Test fusion with multiple epilogue blocks - one with clipping, one without.
+
+    Following the same pattern as test_fuse_reduction_epilogue_multiple_epilogue,
+    this test verifies that fusion works correctly when there are multiple
+    epilogue blocks. The temp buffer is kept because the second epilogue block
+    still needs it.
+    """
+    sch = tir.Schedule(matmul_clipping_multiple_epilogue_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    assert_structural_equal_ignore_global_symbol(
+        sch.mod["main"], matmul_clipping_multiple_epilogue_expected
+    )
+    verify_trace_roundtrip(sch=sch, mod=matmul_clipping_multiple_epilogue_before)
+
+    mod = tvm.compile(sch.mod["main"], target="llvm")
+    assert mod is not None
+
+
+# Test commutative variants of clipping patterns
+@pytest.mark.parametrize(
+    "pattern_func",
+    [
+        lambda temp, lower, upper: T.min(T.max(temp, lower), upper),  # min(max(temp, lower), upper)
+        lambda temp, lower, upper: T.min(upper, T.max(temp, lower)),  # min(upper, max(temp, lower))
+        lambda temp, lower, upper: T.min(T.max(lower, temp), upper),  # min(max(lower, temp), upper)
+        lambda temp, lower, upper: T.max(T.min(temp, upper), lower),  # max(min(temp, upper), lower)
+        lambda temp, lower, upper: T.max(lower, T.min(temp, upper)),  # max(lower, min(temp, upper))
+    ],
+)
+def test_matmul_clipping_commutative_variants(pattern_func):
+    """Test that all commutative variants of clipping patterns are recognized."""
+    lower = -5.0
+    upper = 5.0
+
+    @T.prim_func
+    def test_func(
+        A: T.Buffer((8, 8), "float32"),
+        B: T.Buffer((8, 8), "float32"),
+        D: T.Buffer((8, 8), "float32"),
+    ) -> None:
+        temp = T.alloc_buffer((8, 8), dtype="float32")
+        for i, j, k in T.grid(8, 8, 8):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    temp[vi, vj] = T.float32(0)
+                temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        for i, j in T.grid(8, 8):
+            with T.block("clipping"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                D[vi, vj] = pattern_func(temp[vi, vj], T.float32(lower), T.float32(upper))
+
+    sch = tir.Schedule(test_func, debug_mask="all")
+    # Should not raise an error - all variants should be recognized
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    verify_trace_roundtrip(sch=sch, mod=test_func)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
new file mode 100644
index 000000000000..66e5e52e43db
--- /dev/null
+++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
@@ -0,0 +1,229 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@T.prim_func
+def matmul_bias_relu_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    C: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Original function with separate reduction and epilogue blocks (Bias + ReLU)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    for i, j in T.grid(16, 16):
+        with T.block("bias_relu"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0))
+
+
+@T.prim_func
+def matmul_bias_relu_before_per_iteration(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    C: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Original function with per-iteration ReLU (same semantics as fused)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j in T.grid(16, 16):
+        with T.block("init"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            temp[vi, vj] = T.max(C[vi, vj], T.float32(0))  # ReLU on bias
+
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            # Per-iteration ReLU
+            temp[vi, vj] = T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
+
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = temp[vi, vj]
+
+
+@T.prim_func
+def matmul_bias_relu_expected(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    C: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Expected function after fusion (Bias + ReLU)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = T.max(C[vi, vj], T.float32(0))
+            D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
+
+
+def test_matmul_bias_relu():
+    """Test fusion of matmul with bias + ReLU epilogue."""
+    sch = tir.Schedule(matmul_bias_relu_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "bias_relu")
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_relu_expected)
+    verify_trace_roundtrip(sch=sch, mod=matmul_bias_relu_before)
+
+
+def test_matmul_bias_relu_correctness_unified():
+    """Test that original and fused produce identical results with per-iteration ReLU."""
+    A_np = np.random.randn(16, 16).astype("float32")
+    B_np = np.random.randn(16, 16).astype("float32")
+    C_np = np.random.randn(16, 16).astype("float32")
+
+    # NumPy reference simulating per-iteration ReLU behavior.
+    # The matmul computes temp[i, j] = sum_k A[i, k] * B[j, k], so each k step
+    # adds the outer product of A[:, k] and B[:, k].
+    D_ref = np.maximum(C_np, 0)  # init with ReLU on bias
+    for k in range(16):
+        # np.outer(A_np[:, k], B_np[:, k])[i, j] == A_np[i, k] * B_np[j, k]
+        D_ref = np.maximum(D_ref + np.outer(A_np[:, k], B_np[:, k]), 0)
+
+    # TVM execution (original with per-iteration ReLU)
+    mod_original = tvm.compile(matmul_bias_relu_before_per_iteration, target="llvm")
+    D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+    mod_original(
+        tvm.runtime.tensor(A_np),
+        tvm.runtime.tensor(B_np),
+        tvm.runtime.tensor(C_np),
+        D_original_tvm,
+    )
+
+    # TVM execution (fused)
+    sch = tir.Schedule(matmul_bias_relu_before)
+    sch.fuse_reduction_epilogue("matmul", "bias_relu")
+    mod_fused = tvm.compile(sch.mod["main"], target="llvm")
+    D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+    mod_fused(
+        tvm.runtime.tensor(A_np),
+        tvm.runtime.tensor(B_np),
+        tvm.runtime.tensor(C_np),
+        D_fused_tvm,
+    )
+
+    D_original = D_original_tvm.numpy()
+    D_fused = D_fused_tvm.numpy()
+
+    # Now both should match exactly
+    np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6)
+    np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6)
+    np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6)
+
+
+@T.prim_func
+def matmul_bias_relu_multiple_epilogue_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    C: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    E: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Original function with separate reduction and multiple epilogue blocks (one with ReLU, one without)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    for i, j in T.grid(16, 16):
+        with T.block("bias_relu"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0))
+
+    for i, j in T.grid(16, 16):
+        with T.block("bias"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+@T.prim_func
+def matmul_bias_relu_multiple_epilogue_expected(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    C: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    E: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Expected function after fusion (Bias + ReLU) with multiple epilogue blocks."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = T.max(C[vi, vj], T.float32(0))
+            D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
+    for i, j in T.grid(16, 16):
+        with T.block("bias"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(temp[vi, vj], C[vi, vj])
+            T.writes(E[vi, vj])
+            E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+def test_matmul_bias_relu_multiple_epilogue():
+    """Test fusion with multiple epilogue blocks - one with ReLU, one without.
+
+    Following the same pattern as test_fuse_reduction_epilogue_multiple_epilogue,
+    this test verifies that fusion works correctly when there are multiple
+    epilogue blocks. The temp buffer is kept because the second epilogue block
+    still needs it.
+    """
+    sch = tir.Schedule(matmul_bias_relu_multiple_epilogue_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "bias_relu")
+    assert_structural_equal_ignore_global_symbol(
+        sch.mod["main"], matmul_bias_relu_multiple_epilogue_expected
+    )
+    verify_trace_roundtrip(sch=sch, mod=matmul_bias_relu_multiple_epilogue_before)
+
+    mod = tvm.compile(sch.mod["main"], target="llvm")
+    assert mod is not None
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
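+
+
+# A standalone NumPy sketch (illustrative only, not exercised by the tests
+# above; the helper name is made up) of the semantic difference described in
+# the schedule primitive's docstring warning: post-reduction ReLU applies
+# max(., 0) once after the full sum, while the fused kernel applies it at
+# every reduction step.
+def _relu_semantics_demo():
+    A = np.random.randn(16, 16).astype("float32")
+    B = np.random.randn(16, 16).astype("float32")
+    C = np.random.randn(16, 16).astype("float32")
+    # Post-reduction: bias and ReLU once, after the complete reduction.
+    post_reduction = np.maximum(A @ B.T + C, 0)
+    # Per-iteration: ReLU on the init and after every rank-1 update,
+    # mirroring the fused form produced by fuse_reduction_epilogue.
+    acc = np.maximum(C, 0)
+    for k in range(16):
+        acc = np.maximum(acc + np.outer(A[:, k], B[:, k]), 0)
+    # The two generally differ for this non-linear epilogue.
+    return post_reduction, acc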