diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 0ab6d7e2b699..d9a457ff9ef6 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -986,15 +986,8 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre /*! * \brief Helper to fuse epilogue block into reduction block - * Analyzes epilogue pattern and transforms reduction init/update + * Uses generalized approach to handle any epilogue expression without pattern matching */ -// Epilogue type enumeration -enum class EpilogueType { - Bias, // temp + C - BiasReLU, // max(temp + C, 0) - Clipping, // min(max(temp, lower), upper) -}; - class ReductionEpilogueFuser : public BaseInliner { public: explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, @@ -1002,8 +995,7 @@ class ReductionEpilogueFuser : public BaseInliner { const StmtSRef& scope_root_sref) : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), reduction_block_(reduction_block), - epilogue_block_(epilogue_block_realize->block.get()), - epilogue_type_(EpilogueType::Bias) { + epilogue_block_(epilogue_block_realize->block.get()) { // Disable opaque access check for epilogue fusion // Epilogue blocks can read multiple buffers (temp + bias), which is allowed has_opaque_access = false; @@ -1023,7 +1015,6 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockRealizeNode* reduction_realize); private: - bool AnalyzeEpiloguePattern(const PrimExpr& value); bool IsReductionBlock(const BlockNode* block); void ExtractEpilogueInfo(); // Helper function to extract BufferLoad nodes from BufferStore @@ -1052,15 +1043,16 @@ class ReductionEpilogueFuser : public BaseInliner { const BlockNode* reduction_block_; const BlockNode* epilogue_block_; - PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C - Buffer epilogue_output_buffer_{nullptr}; // Output buffer D + // Generalized approach: store the entire epilogue expression + PrimExpr epilogue_expression_{ + nullptr}; // The entire epilogue expression (e.g., temp + C, max(temp + C, 0)) + const BufferLoadNode* reduction_buffer_load_{ + nullptr}; // The reduction buffer load in epilogue expression + Buffer epilogue_output_buffer_{nullptr}; // Output buffer D ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] BufferRegion epilogue_output_region_{nullptr}; // Write region of D - Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C - BufferRegion epilogue_addend_region_{nullptr}; // Read region of C - EpilogueType epilogue_type_; // Type of epilogue operation - PrimExpr clipping_lower_{nullptr}; // Lower bound for clipping - PrimExpr clipping_upper_{nullptr}; // Upper bound for clipping + Buffer epilogue_addend_buffer_{nullptr}; // Additional buffer (e.g., bias buffer C) + BufferRegion epilogue_addend_region_{nullptr}; // Read region of additional buffer }; bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { @@ -1083,22 +1075,18 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return false; } - // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or - // D[i,j] = min(max(temp[i,j], lower), upper) - if (!AnalyzeEpiloguePattern(inlined_store_->value)) { - // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or Clipping) - return false; - } - - // 5. Verify temp appears exactly once in the epilogue pattern - // This ensures correctness for all supported patterns (Bias, BiasReLU, Clipping) - // The reduction result buffer must be used exactly once in the epilogue expression + // 4. Generalized approach: store the entire epilogue expression + // Verify reduction buffer appears exactly once (required for fusion correctness) if (loads.size() != 1) { // Failure: The reduction result (temp) must be used exactly once in the // epilogue expression for fusion. return false; } + // Store the epilogue expression and reduction buffer load + epilogue_expression_ = inlined_store_->value; + reduction_buffer_load_ = loads[0]; + // 6. Check if producer is a reduction block if (!IsReductionBlock(reduction_block_)) { // Failure: producer is not a reduction block @@ -1111,140 +1099,6 @@ bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue return true; } -bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { - // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias) - if (const auto* add = value.as()) { - const auto* load_a = add->a.as(); - const auto* load_b = add->b.as(); - - bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); - bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); - - // Ensure exactly one operand is from the reduction buffer - if (a_is_target != b_is_target) { - epilogue_addend_ = a_is_target ? add->b : add->a; - epilogue_type_ = EpilogueType::Bias; - return true; - } - } - - // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping) - // Handle all commutative variants of min/max at each level. - - // Helper to check if an expression is a load from the reduction buffer, and - // return the other operand as `other` if so. - auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const PrimExpr& b, - PrimExpr* other) -> bool { - if (const auto* load_a = a.as()) { - if (load_a->buffer.same_as(inlined_buffer_)) { - *other = b; - return true; - } - } - if (const auto* load_b = b.as()) { - if (load_b->buffer.same_as(inlined_buffer_)) { - *other = a; - return true; - } - } - return false; - }; - - // Check for min(max(temp, lower), upper) and commutative variants - if (const auto* min_node = value.as()) { - const MaxNode* max_node = nullptr; - PrimExpr upper; - // Try both (a, b) as possible positions of the inner max - if ((max_node = min_node->a.as())) { - upper = min_node->b; - } else if ((max_node = min_node->b.as())) { - upper = min_node->a; - } - if (max_node != nullptr) { - PrimExpr lower; - if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) { - clipping_lower_ = lower; - clipping_upper_ = upper; - epilogue_type_ = EpilogueType::Clipping; - return true; - } - } - } - - // Check for max(min(temp[i,j], upper), lower) and commutative variants - if (const auto* max_node = value.as()) { - const MinNode* min_node = nullptr; - PrimExpr lower; - // Try both (a, b) as possible positions of the inner min - if ((min_node = max_node->a.as())) { - lower = max_node->b; - } else if ((min_node = max_node->b.as())) { - lower = max_node->a; - } - if (min_node != nullptr) { - PrimExpr upper; - if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) { - clipping_lower_ = lower; - clipping_upper_ = upper; - epilogue_type_ = EpilogueType::Clipping; - return true; - } - } - } - - // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU) - // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j]) - if (const auto* max_node = value.as()) { - // Check if either operand is zero (ReLU: max(x, 0) or max(0, x)) - // Support both integer and float zero constants. - const PrimExpr* add_candidate = nullptr; - bool is_zero_const = false; - auto is_zero_expr = [](const PrimExpr& expr) -> bool { - if (tir::is_zero(expr)) { - return true; - } - if (const auto* float_imm = expr.as()) { - return float_imm->value == 0.0; - } - return false; - }; - - if (is_zero_expr(max_node->a)) { - is_zero_const = true; - add_candidate = &max_node->b; - } else if (is_zero_expr(max_node->b)) { - is_zero_const = true; - add_candidate = &max_node->a; - } - - if (is_zero_const && add_candidate != nullptr) { - if (const auto* add = add_candidate->as()) { - const auto* load_a = add->a.as(); - const auto* load_b = add->b.as(); - - bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); - bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); - - // Ensure exactly one operand is from the reduction buffer - if (a_is_target != b_is_target) { - epilogue_addend_ = a_is_target ? add->b : add->a; - epilogue_type_ = EpilogueType::BiasReLU; - return true; - } - } else if (const auto* load = add_candidate->as()) { - // Handle bias-free ReLU: max(temp, 0) or max(0, temp) - if (load->buffer.same_as(inlined_buffer_)) { - epilogue_addend_ = tir::make_zero(load->dtype); - epilogue_type_ = EpilogueType::BiasReLU; - return true; - } - } - } - } - - return false; -} - bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) { // Check if block has reduction iter vars for (const IterVar& iter : block->iter_vars) { @@ -1268,12 +1122,29 @@ void ReductionEpilogueFuser::ExtractEpilogueInfo() { } } - // Extract epilogue addend buffer and region from epilogue_addend_ - if (const auto* load = epilogue_addend_.as()) { - epilogue_addend_buffer_ = load->buffer; + // Generalized approach: extract all non-reduction buffers from epilogue expression + // Find all buffers in epilogue expression (except the reduction buffer) + struct BufferExtractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (!load->buffer.same_as(reduction_buffer)) { + other_buffers.insert(load->buffer.get()); + } + ExprVisitor::VisitExpr_(load); + } + Buffer reduction_buffer; + std::unordered_set other_buffers; + } extractor; + extractor.reduction_buffer = inlined_buffer_; + extractor(epilogue_expression_); + + // Extract the first non-reduction buffer and its region + // In most cases, there's one additional buffer (e.g., bias buffer) + if (!extractor.other_buffers.empty()) { + const BufferNode* first_buffer = *extractor.other_buffers.begin(); + epilogue_addend_buffer_ = ffi::GetRef(first_buffer); // Find the read region from epilogue block reads for (const BufferRegion& read : epilogue_block_->reads) { - if (read->buffer.same_as(epilogue_addend_buffer_)) { + if (read->buffer.get() == first_buffer) { epilogue_addend_region_ = read; break; } @@ -1308,53 +1179,163 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; } - // 2. Change init to epilogue value based on epilogue type - BufferStore new_init_store; - if (epilogue_type_ == EpilogueType::BiasReLU) { - // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU semantics - PrimExpr init_value = Substitute(epilogue_addend_, var_map); - PrimExpr zero = tir::make_zero(init_value.dtype()); - new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, zero), - Substitute(epilogue_output_indices_, var_map)); - } else if (epilogue_type_ == EpilogueType::Clipping) { - // For Clipping, init should be min(max(init_value, lower), upper) - // Since init is typically 0, this becomes min(max(0, lower), upper) - PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype); - PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, var_map)), - Substitute(clipping_upper_, var_map)); - new_init_store = BufferStore(epilogue_output_buffer_, clipped_init, - Substitute(epilogue_output_indices_, var_map)); - } else { - // Bias: D[vi, vj] = C[vi, vj] - new_init_store = BufferStore(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), - Substitute(epilogue_output_indices_, var_map)); - } + // 2. Generalized init transformation: substitute reduction buffer load with identity element (0) + // Create a substituter to replace reduction_buffer_load_ with identity element + class InitSubstituter : public ExprMutator { + public: + InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem) + : target_buffer_(target_buffer), identity_elem_(identity_elem) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + return identity_elem_; + } + return load; + } + + private: + Buffer target_buffer_; + PrimExpr identity_elem_; + }; + + // Identity element for reduction (assumed to be 0 for addition-based reductions) + PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype); + + // Substitute reduction buffer load with identity element + InitSubstituter init_subst(inlined_buffer_, identity_elem); + PrimExpr init_epilogue = init_subst(epilogue_expression_); + + // Apply index mapping + init_epilogue = Substitute(init_epilogue, var_map); + + // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj]) + arith::Analyzer analyzer; + init_epilogue = analyzer.Simplify(init_epilogue); + + BufferStore new_init_store = BufferStore(epilogue_output_buffer_, init_epilogue, + Substitute(epilogue_output_indices_, var_map)); new_block->init = new_init_store; - // 3. Replace output buffer from temp to D in body - class BufferReplacer : public StmtExprMutator { + // 3. Generalized update transformation: apply epilogue expression with reduction buffer replaced + // If reduction buffer load's parent is Add and other operand is not a reduction buffer, + // remove that operand (bias addend) from update expression + class UpdateSubstituter : public StmtExprMutator { public: - BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, DataType dtype, - PrimExpr clipping_lower = PrimExpr(), PrimExpr clipping_upper = PrimExpr()) + UpdateSubstituter(const Buffer& old_buf, const Buffer& new_buf, const Buffer& reduction_buf, + const PrimExpr& epilogue_expr, const std::unordered_map& var_map) : old_buffer_(old_buf), new_buffer_(new_buf), - epilogue_type_(epilogue_type), - dtype_(dtype), - clipping_lower_(clipping_lower), - clipping_upper_(clipping_upper) {} + reduction_buffer_(reduction_buf), + epilogue_expression_(epilogue_expr), + var_map_(var_map) {} Stmt VisitStmt_(const BufferStoreNode* op) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); if (store->buffer.same_as(old_buffer_)) { - PrimExpr new_value = store->value; - // For ReLU, apply max per iteration to match per-iteration ReLU semantics - if (epilogue_type_ == EpilogueType::BiasReLU) { - PrimExpr zero = tir::make_zero(dtype_); - new_value = Max(new_value, zero); - } else if (epilogue_type_ == EpilogueType::Clipping) { - // For Clipping, apply min(max(value, lower), upper) per iteration - new_value = Min(Max(new_value, clipping_lower_), clipping_upper_); - } + // Replace old_buffer_ in store->value with new_buffer_ to get the reduction update + // expression This ensures store->value references new_buffer_ instead of old_buffer_ + class ReductionUpdateReplacer : public ExprMutator { + public: + ReductionUpdateReplacer(const Buffer& old_buf, const Buffer& new_buf) + : old_buffer_(old_buf), new_buffer_(new_buf) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(old_buffer_)) { + return BufferLoad(new_buffer_, load->indices); + } + return load; + } + + private: + Buffer old_buffer_; + Buffer new_buffer_; + }; + + ReductionUpdateReplacer reduction_replacer(old_buffer_, new_buffer_); + PrimExpr reduction_update = reduction_replacer(store->value); + + // Generalized approach: apply epilogue expression with reduction buffer load replaced + // If reduction buffer load's direct parent is Add and the other operand is not a reduction + // buffer, remove that operand (bias addend) from the update expression + class GeneralizedEpilogueApplier : public ExprMutator { + public: + GeneralizedEpilogueApplier(const Buffer& target_buf, const Buffer& reduction_buf, + const PrimExpr& replacement) + : target_buffer_(target_buf), + reduction_buffer_(reduction_buf), + replacement_(replacement), + found_target_load_(false) {} + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(ExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(target_buffer_)) { + found_target_load_ = true; + // Check if parent is Add (will be checked in VisitExpr_(const AddNode*)) + return replacement_; + } + return load; + } + + PrimExpr VisitExpr_(const AddNode* op) final { + // Visit children first to see if we find the target buffer load + bool found_before = found_target_load_; + found_target_load_ = false; + + PrimExpr a = VisitExpr(op->a); + bool found_in_a = found_target_load_; + found_target_load_ = false; + + PrimExpr b = VisitExpr(op->b); + bool found_in_b = found_target_load_; + + // If target buffer load was found in this Add node + if (found_in_a || found_in_b) { + // Check if the other operand is NOT from the reduction buffer + // If so, it's likely a bias addend that should be removed in update + bool other_is_reduction = false; + if (found_in_a) { + // Check if b is from reduction buffer + if (const auto* load_b = b.as()) { + other_is_reduction = load_b->buffer.same_as(reduction_buffer_); + } + if (!other_is_reduction) { + // b is the bias addend, remove it + return a; + } + } else { // found_in_b + // Check if a is from reduction buffer + if (const auto* load_a = a.as()) { + other_is_reduction = load_a->buffer.same_as(reduction_buffer_); + } + if (!other_is_reduction) { + // a is the bias addend, remove it + return b; + } + } + // If other operand is also from reduction buffer, keep the Add + return Add(a, b); + } + + // Target buffer load not found in this Add, return as is + found_target_load_ = found_before; + return Add(a, b); + } + + private: + const Buffer& target_buffer_; + const Buffer& reduction_buffer_; + const PrimExpr& replacement_; + bool found_target_load_; + }; + + GeneralizedEpilogueApplier applier(old_buffer_, reduction_buffer_, reduction_update); + PrimExpr new_value = applier(epilogue_expression_); + + // Apply index mapping + new_value = Substitute(new_value, var_map_); + return BufferStore(new_buffer_, new_value, store->indices); } return store; @@ -1371,19 +1352,16 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti private: Buffer old_buffer_; Buffer new_buffer_; - EpilogueType epilogue_type_; - DataType dtype_; - PrimExpr clipping_lower_; - PrimExpr clipping_upper_; + Buffer reduction_buffer_; + PrimExpr epilogue_expression_; + std::unordered_map var_map_; }; - DataType dtype = epilogue_output_buffer_->dtype; - PrimExpr clipping_lower_subst = - epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, var_map) : PrimExpr(); - PrimExpr clipping_upper_subst = - epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, var_map) : PrimExpr(); - BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, epilogue_type_, dtype, - clipping_lower_subst, clipping_upper_subst); + // Apply index mapping to epilogue expression first + PrimExpr epilogue_expr_mapped = Substitute(epilogue_expression_, var_map); + + UpdateSubstituter replacer(inlined_buffer_, epilogue_output_buffer_, inlined_buffer_, + epilogue_expr_mapped, var_map); new_block->body = replacer(reduction_block->body); // 4. Update write regions @@ -1398,21 +1376,22 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti } new_block->writes = new_writes; - // 5. Update read regions (C first, then A, B) + // 5. Update read regions: add all buffers from epilogue expression (except reduction buffer) ffi::Array new_reads; std::unordered_set read_bufs; - // Add C buffer read first (used in init) - if (epilogue_addend_buffer_.defined()) { - new_reads.push_back(BufferRegion(epilogue_addend_buffer_, - Substitute(epilogue_addend_region_->region, var_map))); - read_bufs.insert(epilogue_addend_buffer_.get()); + // Add all non-reduction buffers from epilogue expression + for (const BufferRegion& read : epilogue_block_->reads) { + if (!read->buffer.same_as(inlined_buffer_)) { + new_reads.push_back(BufferRegion(read->buffer, Substitute(read->region, var_map))); + read_bufs.insert(read->buffer.get()); + } } - // Add existing read regions (A, B, etc.) + // Add existing read regions from reduction block (A, B, etc.) for (const BufferRegion& read : reduction_block->reads) { if (!read->buffer.same_as(inlined_buffer_)) { - // Only add non-temp buffers + // Only add non-temp buffers that haven't been added yet if (read_bufs.find(read->buffer.get()) == read_bufs.end()) { new_reads.push_back(read); read_bufs.insert(read->buffer.get());