From 0fc40e7c2092692c48d1b6afc18ae3d9b34568f7 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Tue, 4 Nov 2025 18:27:14 +0900 Subject: [PATCH 1/4] [TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue into reduction init Currently it is not possible to fuse an epilogue operation (e.g., bias addition) into a reduction block's initialization statement. This limitation prevents leveraging hardware-specific instructions that support bias accumulation in vector ISAs, such as MACC (multiply-accumulate with bias) instructions. This commit implements a new schedule primitive 'fuse_reduction_epilogue' that addresses the problem described in: https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066 The primitive transforms the following pattern: Before: for i, j, k in T.grid(M, N, K): with T.block("matmul"): with T.init(): temp[vi, vj] = 0 temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(M, N): with T.block("bias_add"): D[vi, vj] = temp[vi, vj] + C[vi, vj] After: for i, j, k in T.grid(M, N, K): with T.block("matmul"): T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) with T.init(): D[vi, vj] = C[vi, vj] # Fused epilogue into init D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] The transformation removes the intermediate temp buffer and the separate epilogue block, enabling better tensorization opportunities for hardware with bias accumulation support. Implementation: - ReductionEpilogueFuser class for pattern validation and IR transformation - BodyPatternAllowFusion: Validates epilogue can be fused - AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C) - ExtractEpilogueInfo: Extracts buffer and region information - CreateFusedReductionBlock: Creates single block with modified T.init() - SingleBlockFusionReplacer: Replaces blocks and removes temp buffer - Variable mapping between epilogue and reduction block iter vars - Proper buffer and region updates with correct read/write ordering - FFI bindings and Python API following TVM conventions Changes: - src/tir/schedule/primitive/compute_inline.cc: Core implementation (~430 lines) - src/tir/schedule/primitive.h: Function declaration - include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode - src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode implementation - src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode implementation - src/tir/schedule/schedule.cc: FFI binding registration - python/tvm/tir/schedule/schedule.py: Python API with documentation - tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py: Comprehensive tests including basic fusion, float32 variant, numerical correctness verification, and trace roundtrip validation Run tests with: pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py -v --- include/tvm/tir/schedule/schedule.h | 7 + python/tvm/tir/schedule/schedule.py | 27 ++ src/tir/schedule/concrete_schedule.cc | 9 + src/tir/schedule/concrete_schedule.h | 2 + src/tir/schedule/primitive.h | 8 + src/tir/schedule/primitive/compute_inline.cc | 414 ++++++++++++++++++ src/tir/schedule/schedule.cc | 4 +- src/tir/schedule/traced_schedule.cc | 11 + src/tir/schedule/traced_schedule.h | 1 + ...st_tir_schedule_fuse_reduction_epilogue.py | 160 +++++++ 10 files changed, 642 insertions(+), 1 deletion(-) create mode 100644 tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 60deae801f87..a768a7dd4f31 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -608,6 +608,13 @@ class ScheduleNode : public runtime::Object { * \param block The block to be inlined to its producer */ virtual void ReverseComputeInline(const BlockRV& block) = 0; + /*! + * \brief Fuse an epilogue block into a reduction block + * \param reduction_block The reduction block (e.g., matmul) + * \param epilogue_block The epilogue block to be fused (e.g., bias add) + */ + virtual void FuseReductionEpilogue(const BlockRV& reduction_block, + const BlockRV& epilogue_block) = 0; /******** Schedule: Reduction ********/ /*! * \brief Decompose a reduction block into two separate blocks. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index ffa7e7174f28..92d082274682 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2345,6 +2345,33 @@ def after_inline(a: T.handle, c: T.handle) -> None: # pylint: disable-next=no-member _ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore + @type_checked + def fuse_reduction_epilogue( + self, + reduction_block: Union[BlockRV, str], + epilogue_block: Union[BlockRV, str], + ) -> None: + """Fuse an epilogue block into a reduction block. + + It requires: + 1) The reduction block is a complete reduction block + 2) The epilogue block only reads from the reduction block's output + 3) The epilogue performs a simple addition: output = reduction_result + bias + + Parameters + ---------- + reduction_block : Union[BlockRV, str] + The reduction block (e.g., matmul) + epilogue_block : Union[BlockRV, str] + The epilogue block to be fused (e.g., bias add) + """ + reduction_block = self._normalize_block_arg(reduction_block) + epilogue_block = self._normalize_block_arg(epilogue_block) + # pylint: disable-next=no-member + _ffi_api.ScheduleFuseReductionEpilogue( + self, reduction_block, epilogue_block + ) # type: ignore + ########## Schedule: Reduction ########## @type_checked diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 89ece537713d..00f421e733e2 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -832,6 +832,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { this->state_->DebugVerify(); } +void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv, + const BlockRV& epilogue_block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv), + this->GetSRef(epilogue_block_rv)); + TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Block Annotation ********/ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index b6f87a3aae8f..7ee54961415b 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode { int index = -1) override; void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; + void FuseReductionEpilogue(const BlockRV& reduction_block, + const BlockRV& epilogue_block) override; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0c3e5a0efd21..1af0033791f4 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -509,6 +509,14 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref); * \param block_sref The sref to the block to be inlined to its producer */ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref); +/*! + * \brief Fuse an epilogue block into a reduction block + * \param self The state of the schedule + * \param reduction_block_sref The sref to the reduction block + * \param epilogue_block_sref The sref to the epilogue block to be fused + */ +TVM_DLL void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref, + const StmtSRef& epilogue_block_sref); /******** Schedule: Reduction ********/ /*! * \brief Decompose a reduction block into two separate blocks. diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index e480c68ff4ad..4406574562da 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -984,6 +984,391 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre ReverseComputeInlineImpl(self, consumer_block_sref); } +/*! + * \brief Helper to fuse epilogue block into reduction block + * Analyzes epilogue pattern and transforms reduction init/update + */ +class ReductionEpilogueFuser : public BaseInliner { + public: + explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block, + const BlockRealize& epilogue_block_realize, + const StmtSRef& scope_root_sref, const IRModule& mod) + : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref), + reduction_block_(reduction_block), + epilogue_block_(epilogue_block_realize->block.get()), + mod_(mod) {} + + bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize); + + // Step 2: Create single fused reduction block + Block CreateFusedReductionBlock(const BlockNode* reduction_block, + const BlockRealizeNode* reduction_realize); + + private: + bool AnalyzeEpiloguePattern(const PrimExpr& value); + bool IsReductionBlock(const BlockNode* block); + void ExtractEpilogueInfo(); + // Helper function to extract BufferLoad nodes from BufferStore + static std::vector ExtractBufferLoad(const Buffer& buffer, + const BufferStoreNode* from) { + struct Extractor : public ExprVisitor { + void VisitExpr_(const BufferLoadNode* load) final { + if (load->buffer.get() == buffer) { + result.push_back(load); + } + ExprVisitor::VisitExpr_(load); + } + const BufferNode* buffer; + std::vector result; + } extractor; + extractor.buffer = buffer.get(); + for (const PrimExpr& expr : from->indices) { + extractor(expr); + } + extractor(from->value); + return std::move(extractor.result); + } + + const BlockNode* reduction_block_; + const BlockNode* epilogue_block_; + const IRModule& mod_; + PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C + Buffer epilogue_output_buffer_{nullptr}; // Output buffer D + ffi::Array epilogue_output_indices_{nullptr}; // Indices of D[vi, vj] + BufferRegion epilogue_output_region_{nullptr}; // Write region of D + Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C + BufferRegion epilogue_addend_region_{nullptr}; // Read region of C +}; + +bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) { + // 1. Validate predicate + if (!is_one(epilogue_block_realize->predicate)) { + // Failure: Predicate in epilogue block is not supported + return false; + } + + // 2. Check if epilogue body is BufferStore + if (inlined_store_ == nullptr) { + // Failure: epilogue block body is not BufferStore + return false; + } + + // 3. Check if epilogue reads from reduction buffer + std::vector loads = ExtractBufferLoad(inlined_buffer_, inlined_store_); + if (loads.size() == 0) { + // Failure: no BufferLoad from the reduction buffer + return false; + } + + // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] + if (!AnalyzeEpiloguePattern(inlined_store_->value)) { + // Failure: epilogue is not a simple addition pattern + return false; + } + + // 5. Check if producer is a reduction block + if (!IsReductionBlock(reduction_block_)) { + // Failure: producer is not a reduction block + return false; + } + + // 6. Extract epilogue information (output buffer, indices, regions, etc.) + ExtractEpilogueInfo(); + + return true; +} + +bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) { + // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] + if (const auto* add = value.as()) { + const auto* load_a = add->a.as(); + const auto* load_b = add->b.as(); + + bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_); + bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_); + + // Ensure exactly one operand is from the reduction buffer + if (a_is_target != b_is_target) { + epilogue_addend_ = a_is_target ? add->b : add->a; + return true; + } + } + + return false; +} + +bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) { + // Check if block has reduction iter vars + for (const IterVar& iter : block->iter_vars) { + if (iter->iter_type == kCommReduce) { + return true; + } + } + return false; +} + +void ReductionEpilogueFuser::ExtractEpilogueInfo() { + // Extract epilogue output buffer and indices + epilogue_output_buffer_ = inlined_store_->buffer; + epilogue_output_indices_ = inlined_store_->indices; + + // Extract epilogue output region from epilogue block writes + for (const BufferRegion& write : epilogue_block_->writes) { + if (write->buffer.same_as(epilogue_output_buffer_)) { + epilogue_output_region_ = write; + break; + } + } + + // Extract epilogue addend buffer and region from epilogue_addend_ + if (const auto* load = epilogue_addend_.as()) { + epilogue_addend_buffer_ = load->buffer; + // Find the read region from epilogue block reads + for (const BufferRegion& read : epilogue_block_->reads) { + if (read->buffer.same_as(epilogue_addend_buffer_)) { + epilogue_addend_region_ = read; + break; + } + } + } +} + +Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block, + const BlockRealizeNode* reduction_realize) { + ObjectPtr new_block = ffi::make_object(*reduction_block); + + // 1. Keep all iter vars (data parallel + reduction) + new_block->iter_vars = reduction_block->iter_vars; + + // 2. Map epilogue block vars to reduction block vars + std::vector reduction_data_vars; + for (const IterVar& iter_var : reduction_block->iter_vars) { + if (iter_var->iter_type == IterVarType::kDataPar) { + reduction_data_vars.push_back(iter_var->var); + } + } + std::vector epilogue_data_vars; + for (const IterVar& iter_var : epilogue_block_->iter_vars) { + if (iter_var->iter_type == IterVarType::kDataPar) { + epilogue_data_vars.push_back(iter_var->var); + } + } + + ICHECK_EQ(reduction_data_vars.size(), epilogue_data_vars.size()) + << "ValueError: The number of data parallel iter vars must be the same in the reduction " + "and epilogue blocks."; + + std::unordered_map var_map; + for (size_t i = 0; i < reduction_data_vars.size(); ++i) { + var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; + } + + // 3. Change init to epilogue value: D[vi, vj] = C[vi, vj] + BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), + Substitute(epilogue_output_indices_, var_map)); + new_block->init = new_init_store; + + // 4. Replace output buffer from temp to D in body + class BufferReplacer : public StmtExprMutator { + public: + BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + if (store->buffer.same_as(old_buffer_)) { + return BufferStore(new_buffer_, store->value, store->indices); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + if (load->buffer.same_as(old_buffer_)) { + return BufferLoad(new_buffer_, load->indices); + } + return load; + } + + private: + Buffer old_buffer_; + Buffer new_buffer_; + }; + + BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_); + new_block->body = replacer(reduction_block->body); + + // 5. Update write regions + ffi::Array new_writes; + for (const BufferRegion& write : reduction_block->writes) { + if (write->buffer.same_as(inlined_buffer_)) { + new_writes.push_back( + BufferRegion(epilogue_output_buffer_, Substitute(write->region, var_map))); + } else { + new_writes.push_back(write); + } + } + new_block->writes = new_writes; + + // 6. Update read regions (C first, then A, B) + ffi::Array new_reads; + std::unordered_set read_bufs; + + // Add C buffer read first (used in init) + if (epilogue_addend_buffer_.defined()) { + new_reads.push_back(BufferRegion(epilogue_addend_buffer_, + Substitute(epilogue_addend_region_->region, var_map))); + read_bufs.insert(epilogue_addend_buffer_.get()); + } + + // Add existing read regions (A, B, etc.) + for (const BufferRegion& read : reduction_block->reads) { + if (!read->buffer.same_as(inlined_buffer_)) { + // Only add non-temp buffers + if (read_bufs.find(read->buffer.get()) == read_bufs.end()) { + new_reads.push_back(read); + read_bufs.insert(read->buffer.get()); + } + } + } + + new_block->reads = new_reads; + + return Block(new_block); +} + +/*! + * \brief Helper class to replace reduction and epilogue blocks with a single fused block + */ +class SingleBlockFusionReplacer : public StmtMutator { + public: + static Block Replace(Block old_scope_root, Block new_fused_block, Block old_reduction_block, + Block old_epilogue_block, Buffer reduction_buffer) { + SingleBlockFusionReplacer replacer(std::move(new_fused_block), std::move(old_reduction_block), + std::move(old_epilogue_block), std::move(reduction_buffer)); + Block result = Downcast(replacer(std::move(old_scope_root))); + + // Remove intermediate temp buffer + BlockNode* p = result.CopyOnWrite(); + ffi::Array new_alloc_buffers; + for (const Buffer& buf : p->alloc_buffers) { + if (!buf.same_as(replacer.reduction_buffer_)) { + new_alloc_buffers.push_back(buf); + } + } + p->alloc_buffers = new_alloc_buffers; + + return result; + } + + private: + explicit SingleBlockFusionReplacer(Block new_fused_block, Block old_reduction_block, + Block old_epilogue_block, Buffer reduction_buffer) + : new_fused_block_(std::move(new_fused_block)), + old_reduction_block_(std::move(old_reduction_block)), + old_epilogue_block_(std::move(old_epilogue_block)), + reduction_buffer_(std::move(reduction_buffer)) {} + + Stmt VisitStmt_(const ForNode* loop) final { + Stmt mutated_body = StmtMutator::VisitStmt(loop->body); + // Remove empty loops (containing only Evaluate(0)) + if (mutated_body.as()) { + return mutated_body; // Return Evaluate(0) to be removed by SeqStmt + } + + return For(loop->loop_var, loop->min, loop->extent, loop->kind, mutated_body, + loop->thread_binding, loop->annotations); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block.same_as(old_reduction_block_)) { + // Replace reduction block with new fused block + ObjectPtr new_realize = ffi::make_object(*realize); + new_realize->block = new_fused_block_; + return BlockRealize(new_realize); + } else if (realize->block.same_as(old_epilogue_block_)) { + // Remove epilogue block completely + return Evaluate(0); + } + return StmtMutator::VisitStmt_(realize); + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + ffi::Array new_stmts; + for (const Stmt& stmt : seq->seq) { + Stmt new_stmt = VisitStmt(stmt); + // Remove Evaluate(0) + if (!new_stmt.as()) { + new_stmts.push_back(new_stmt); + } + } + return SeqStmt::Flatten(new_stmts); + } + + private: + Block new_fused_block_; + Block old_reduction_block_; + Block old_epilogue_block_; + Buffer reduction_buffer_; +}; + +void FuseReductionEpilogueImpl(ScheduleState self, const StmtSRef& reduction_block_sref, + const StmtSRef& epilogue_block_sref, bool check_only = false) { + const BlockNode* _reduction_block = TVM_SREF_TO_BLOCK(reduction_block_sref); + const BlockNode* _epilogue_block = TVM_SREF_TO_BLOCK(epilogue_block_sref); + + Block reduction_block = ffi::GetRef(_reduction_block); + Block epilogue_block = ffi::GetRef(_epilogue_block); + BlockRealize epilogue_block_realize = GetBlockRealize(self, epilogue_block_sref); + + // Step 1. Get the scope block + StmtSRef scope_root_sref = + GetScopeRoot(self, epilogue_block_sref, /*require_stage_pipeline=*/true); + + // Step 2. Get the reduction buffer (intermediate buffer) + Buffer reduction_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, reduction_block); + + // Step 3. Check completeness and reduction block properties + CheckReductionBlock(self, reduction_block_sref, scope_root_sref); + CheckCompleteBlock(self, epilogue_block_sref, scope_root_sref); + CheckNotOutputBlock(self, reduction_block_sref, scope_root_sref); + + // Step 4. Analyze the epilogue pattern + ReductionEpilogueFuser fuser(reduction_buffer, _reduction_block, epilogue_block_realize, + scope_root_sref, self->mod); + if (!fuser.BodyPatternAllowFusion(epilogue_block_realize)) { + throw BodyAnalysisError(true, self->mod, epilogue_block); + } + + if (check_only) { + return; + } + + // Step 5. Create single fused reduction block + BlockRealize reduction_realize = GetBlockRealize(self, reduction_block_sref); + Block fused_block = fuser.CreateFusedReductionBlock(_reduction_block, reduction_realize.get()); + + // Step 6. Transform and replace IR + const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); + + Block new_scope_root = + SingleBlockFusionReplacer::Replace(ffi::GetRef(old_scope_root), fused_block, + reduction_block, epilogue_block, reduction_buffer); + + // Step 7. Update schedule state + ffi::Map block_reuse; + block_reuse.Set(ffi::GetRef(old_scope_root), new_scope_root); + block_reuse.Set(reduction_block, fused_block); + self->Replace(scope_root_sref, new_scope_root, block_reuse); + + // Step 8. Update BlockInfo + self->UpdateScopeBlockInfo(GetBlockRealize(self, scope_root_sref)); +} + +void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref, + const StmtSRef& epilogue_block_sref) { + FuseReductionEpilogueImpl(self, reduction_block_sref, epilogue_block_sref); +} + /******** InstructionKind Registration ********/ struct ComputeInlineTraits : public UnpackedInstTraits { @@ -1035,5 +1420,34 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraits { + static constexpr const char* kName = "FuseReductionEpilogue"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV reduction_block_rv, + BlockRV epilogue_block_rv) { + return sch->FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv); + } + + static ffi::String UnpackedAsPython(ffi::Array outputs, + ffi::String reduction_block_rv, + ffi::String epilogue_block_rv) { + PythonAPICall py("fuse_reduction_epilogue"); + py.Input("reduction_block", reduction_block_rv); + py.Input("epilogue_block", epilogue_block_rv); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(FuseReductionEpilogueTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 845bbb5cc278..35b221561978 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -227,7 +227,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("tir.schedule.ScheduleComputeAt", &ScheduleNode::ComputeAt) .def_method("tir.schedule.ScheduleReverseComputeAt", &ScheduleNode::ReverseComputeAt) .def_method("tir.schedule.ScheduleComputeInline", &ScheduleNode::ComputeInline) - .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline); + .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline) + .def_method("tir.schedule.ScheduleFuseReductionEpilogue", + &ScheduleNode::FuseReductionEpilogue); } /******** (FFI) Reduction ********/ TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 8129f43833c4..72606f243d69 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -532,6 +532,17 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { /*outputs=*/{})); } +void TracedScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv, + const BlockRV& epilogue_block_rv) { + ConcreteScheduleNode::FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv); + + static const InstructionKind& kind = InstructionKind::Get("FuseReductionEpilogue"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{reduction_block_rv, epilogue_block_rv}, + /*attrs=*/{}, + /*outputs=*/{})); +} + /******** Schedule: Reduction ********/ BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 0b91dc283392..8c7b16a47e8d 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -109,6 +109,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { int index = -1) final; void ComputeInline(const BlockRV& block_rv) final; void ReverseComputeInline(const BlockRV& block_rv) final; + void FuseReductionEpilogue(const BlockRV& reduction_block, const BlockRV& epilogue_block) final; /******** Schedule: Reduction ********/ BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final; BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py new file mode 100644 index 000000000000..18d5d58ec644 --- /dev/null +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) + +# pylint: disable=no-member,invalid-name,unused-variable + + +########## Test cases for fuse_reduction_epilogue ########## + + +@T.prim_func +def matmul_bias_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + """Original function with separate reduction and epilogue blocks.""" + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + """Expected function after fusing epilogue into reduction init.""" + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + + +@T.prim_func +def matmul_bias_fp32_before( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + """Float32 version for additional coverage.""" + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j in T.grid(32, 32): + with T.block("bias"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_fp32_expected( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + """Expected float32 version after fusion.""" + for i, j, k in T.grid(32, 32, 32): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_fuse_reduction_epilogue_basic(): + """Test basic fusion of epilogue into reduction init.""" + sch = tir.Schedule(matmul_bias_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_before) + + +def test_fuse_reduction_epilogue_fp32(): + """Test fusion with float32 data type.""" + sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all") + sch.fuse_reduction_epilogue("matmul", "bias") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before) + + +def test_fuse_reduction_epilogue_numerical_correctness(): + """Test that fusion preserves numerical correctness.""" + import numpy as np + + # Generate random test data + np.random.seed(0) + A_np = np.random.randint(-128, 127, size=(16, 16), dtype=np.int8) + B_np = np.random.randint(-128, 127, size=(16, 16), dtype=np.int8) + C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype=np.int32) + D_original = np.zeros((16, 16), dtype=np.int32) + D_fused = np.zeros((16, 16), dtype=np.int32) + + # Run original version + mod_original = tvm.build(matmul_bias_before, target="llvm") + A_tvm = tvm.runtime.tensor(A_np) + B_tvm = tvm.runtime.tensor(B_np) + C_tvm = tvm.runtime.tensor(C_np) + D_tvm_original = tvm.runtime.tensor(D_original) + mod_original(A_tvm, B_tvm, C_tvm, D_tvm_original) + + # Run fused version + sch = tir.Schedule(matmul_bias_before) + sch.fuse_reduction_epilogue("multiply", "add") + mod_fused = tvm.build(sch.mod["main"], target="llvm") + D_tvm_fused = tvm.runtime.tensor(D_fused) + mod_fused(A_tvm, B_tvm, C_tvm, D_tvm_fused) + + # Verify results match + tvm.testing.assert_allclose(D_tvm_original.numpy(), D_tvm_fused.numpy()) + + +if __name__ == "__main__": + tvm.testing.main() From 71ee6b1d2475d4cbbd03c4091235d8e2bdd86167 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Tue, 18 Nov 2025 15:54:49 +0900 Subject: [PATCH 2/4] [TIR][Schedule] Support multiple epilogue blocks in FuseReductionEpilogue - Add CheckBufferStillUsed helper function to check if reduction buffer is still referenced by other blocks after fusion - Only remove intermediate temp buffer if no other blocks reference it - Add test case for multiple epilogue blocks scenario where one epilogue is fused while another still uses the intermediate buffer - This addresses the case where multiple epilogue blocks use the same reduction output, ensuring the temp buffer is preserved when needed Related issue: https://discuss.tvm.apache.org/t/... --- src/tir/schedule/primitive/compute_inline.cc | 95 ++++- ...st_tir_schedule_fuse_reduction_epilogue.py | 374 ++++++++++-------- 2 files changed, 302 insertions(+), 167 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 4406574562da..ced04da37d60 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1236,6 +1236,82 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti return Block(new_block); } +/*! + * \brief Check if a buffer is still referenced by other blocks in the scope + */ +static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) { + class BufferUsageChecker : public StmtVisitor { + public: + explicit BufferUsageChecker(const Buffer& buffer) : buffer_(buffer) {} + + bool CheckStmt(const Stmt& stmt) { + found_usage_ = false; + VisitStmt(stmt); + return found_usage_; + } + + private: + void VisitStmt_(const BlockRealizeNode* op) final { + if (found_usage_) return; + + if (!op || !op->block.defined()) { + StmtVisitor::VisitStmt_(op); + return; + } + + const BlockNode* block = op->block.get(); + if (!block) { + StmtVisitor::VisitStmt_(op); + return; + } + + // Check reads + for (const BufferRegion& read : block->reads) { + if (read->buffer.same_as(buffer_)) { + found_usage_ = true; + return; + } + } + + // Check writes + for (const BufferRegion& write : block->writes) { + if (write->buffer.same_as(buffer_)) { + found_usage_ = true; + return; + } + } + + // Continue visiting nested blocks + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BlockNode* op) final { + if (found_usage_) return; + if (!op) return; + + // Check alloc_buffers + for (const Buffer& buf : op->alloc_buffers) { + if (buf.same_as(buffer_)) { + found_usage_ = true; + return; + } + } + + StmtVisitor::VisitStmt_(op); + } + + const Buffer& buffer_; + bool found_usage_{false}; + }; + + if (!scope_root->body.defined()) { + return false; + } + + BufferUsageChecker checker(buffer); + return checker.CheckStmt(scope_root->body); +} + /*! * \brief Helper class to replace reduction and epilogue blocks with a single fused block */ @@ -1247,15 +1323,20 @@ class SingleBlockFusionReplacer : public StmtMutator { std::move(old_epilogue_block), std::move(reduction_buffer)); Block result = Downcast(replacer(std::move(old_scope_root))); - // Remove intermediate temp buffer - BlockNode* p = result.CopyOnWrite(); - ffi::Array new_alloc_buffers; - for (const Buffer& buf : p->alloc_buffers) { - if (!buf.same_as(replacer.reduction_buffer_)) { - new_alloc_buffers.push_back(buf); + // Check if reduction_buffer is still referenced by other blocks + bool buffer_still_used = CheckBufferStillUsed(result, reduction_buffer); + + // Remove intermediate temp buffer only if it's not used by other blocks + if (!buffer_still_used) { + BlockNode* p = result.CopyOnWrite(); + ffi::Array new_alloc_buffers; + for (const Buffer& buf : p->alloc_buffers) { + if (!buf.same_as(reduction_buffer)) { + new_alloc_buffers.push_back(buf); + } } + p->alloc_buffers = new_alloc_buffers; } - p->alloc_buffers = new_alloc_buffers; return result; } diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index 18d5d58ec644..37076dd10a31 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -1,160 +1,214 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-function-docstring,missing-module-docstring -import pytest -import tvm -import tvm.testing -from tvm import tir -from tvm.script import tir as T -from tvm.tir.schedule.testing import ( - verify_trace_roundtrip, - assert_structural_equal_ignore_global_symbol, -) - -# pylint: disable=no-member,invalid-name,unused-variable - - -########## Test cases for fuse_reduction_epilogue ########## - - -@T.prim_func -def matmul_bias_before( - A: T.Buffer((16, 16), "int8"), - B: T.Buffer((16, 16), "int8"), - C: T.Buffer((16, 16), "int32"), - D: T.Buffer((16, 16), "int32"), -) -> None: - """Original function with separate reduction and epilogue blocks.""" - temp = T.alloc_buffer((16, 16), dtype="int32") - for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - temp[vi, vj] = T.int32(0) - temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") - for i, j in T.grid(16, 16): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - D[vi, vj] = temp[vi, vj] + C[vi, vj] - - -@T.prim_func -def matmul_bias_expected( - A: T.Buffer((16, 16), "int8"), - B: T.Buffer((16, 16), "int8"), - C: T.Buffer((16, 16), "int32"), - D: T.Buffer((16, 16), "int32"), -) -> None: - """Expected function after fusing epilogue into reduction init.""" - for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) - T.writes(D[vi, vj]) - with T.init(): - D[vi, vj] = C[vi, vj] - D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") - - -@T.prim_func -def matmul_bias_fp32_before( - A: T.Buffer((32, 32), "float32"), - B: T.Buffer((32, 32), "float32"), - C: T.Buffer((32, 32), "float32"), - D: T.Buffer((32, 32), "float32"), -) -> None: - """Float32 version for additional coverage.""" - temp = T.alloc_buffer((32, 32), dtype="float32") - for i, j, k in T.grid(32, 32, 32): - with T.block("matmul"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - temp[vi, vj] = T.float32(0) - temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] - for i, j in T.grid(32, 32): - with T.block("bias"): - vi, vj = T.axis.remap("SS", [i, j]) - D[vi, vj] = temp[vi, vj] + C[vi, vj] - - -@T.prim_func -def matmul_bias_fp32_expected( - A: T.Buffer((32, 32), "float32"), - B: T.Buffer((32, 32), "float32"), - C: T.Buffer((32, 32), "float32"), - D: T.Buffer((32, 32), "float32"), -) -> None: - """Expected float32 version after fusion.""" - for i, j, k in T.grid(32, 32, 32): - with T.block("matmul"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) - T.writes(D[vi, vj]) - with T.init(): - D[vi, vj] = C[vi, vj] - D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] - - -def test_fuse_reduction_epilogue_basic(): - """Test basic fusion of epilogue into reduction init.""" - sch = tir.Schedule(matmul_bias_before, debug_mask="all") - sch.fuse_reduction_epilogue("multiply", "add") - assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected) - verify_trace_roundtrip(sch=sch, mod=matmul_bias_before) - - -def test_fuse_reduction_epilogue_fp32(): - """Test fusion with float32 data type.""" - sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all") - sch.fuse_reduction_epilogue("matmul", "bias") - assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected) - verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before) - - -def test_fuse_reduction_epilogue_numerical_correctness(): - """Test that fusion preserves numerical correctness.""" - import numpy as np - - # Generate random test data - np.random.seed(0) - A_np = np.random.randint(-128, 127, size=(16, 16), dtype=np.int8) - B_np = np.random.randint(-128, 127, size=(16, 16), dtype=np.int8) - C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype=np.int32) - D_original = np.zeros((16, 16), dtype=np.int32) - D_fused = np.zeros((16, 16), dtype=np.int32) - - # Run original version - mod_original = tvm.build(matmul_bias_before, target="llvm") - A_tvm = tvm.runtime.tensor(A_np) - B_tvm = tvm.runtime.tensor(B_np) - C_tvm = tvm.runtime.tensor(C_np) - D_tvm_original = tvm.runtime.tensor(D_original) - mod_original(A_tvm, B_tvm, C_tvm, D_tvm_original) - - # Run fused version - sch = tir.Schedule(matmul_bias_before) - sch.fuse_reduction_epilogue("multiply", "add") - mod_fused = tvm.build(sch.mod["main"], target="llvm") - D_tvm_fused = tvm.runtime.tensor(D_fused) - mod_fused(A_tvm, B_tvm, C_tvm, D_tvm_fused) - - # Verify results match - tvm.testing.assert_allclose(D_tvm_original.numpy(), D_tvm_fused.numpy()) - - -if __name__ == "__main__": - tvm.testing.main() +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) +import numpy as np + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul_bias_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + + +@T.prim_func +def matmul_bias_fp32_before( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j in T.grid(32, 32): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_fp32_expected( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def matmul_bias_multiple_epilogue_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), + E: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + for i, j in T.grid(16, 16): + with T.block("add2"): + vi, vj = T.axis.remap("SS", [i, j]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_multiple_epilogue_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), + E: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add2"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(temp[vi, vj], C[vi, vj]) + T.writes(E[vi, vj]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +def test_fuse_reduction_epilogue_basic(): + sch = tir.Schedule(matmul_bias_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_before) + + +def test_fuse_reduction_epilogue_fp32(): + sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before) + + +def test_fuse_reduction_epilogue_numerical_correctness(): + sch_original = tir.Schedule(matmul_bias_before, debug_mask="all") + mod_original = tvm.compile(sch_original.mod["main"], target="llvm") + + sch_fused = tir.Schedule(matmul_bias_before, debug_mask="all") + sch_fused.fuse_reduction_epilogue("multiply", "add") + mod_fused = tvm.compile(sch_fused.mod["main"], target="llvm") + + A_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") + B_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") + C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype="int32") + + expected = (A_np.astype("int32") @ B_np.T.astype("int32")) + C_np + + D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) + D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) + + mod_original(tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), + tvm.runtime.tensor(C_np), D_original_tvm) + + mod_fused(tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), + tvm.runtime.tensor(C_np), D_fused_tvm) + + D_original = D_original_tvm.numpy() + D_fused = D_fused_tvm.numpy() + + np.testing.assert_allclose(D_original, expected, rtol=1e-5) + np.testing.assert_allclose(D_fused, expected, rtol=1e-5) + np.testing.assert_allclose(D_fused, D_original, rtol=1e-5) + + +def test_fuse_reduction_epilogue_multiple_epilogue(): + sch = tir.Schedule(matmul_bias_multiple_epilogue_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_multiple_epilogue_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_multiple_epilogue_before) + + mod = tvm.compile(sch.mod["main"], target="llvm") + assert mod is not None + + +if __name__ == "__main__": + tvm.testing.main() \ No newline at end of file From b19f54772a3c0a178a5ac1cc95b40dd8de10583d Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Tue, 18 Nov 2025 16:26:46 +0900 Subject: [PATCH 3/4] [TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue into reduction init Currently it is not possible to fuse an epilogue operation (e.g., bias addition) into a reduction block's initialization statement. This limitation prevents leveraging hardware-specific instructions that support bias accumulation in vector ISAs, such as MACC (multiply-accumulate with bias) instructions. This commit implements a new schedule primitive 'fuse_reduction_epilogue' that addresses the problem described in: https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066 The primitive transforms the following pattern: Before: for i, j, k in T.grid(M, N, K): with T.block("matmul"): with T.init(): temp[vi, vj] = 0 temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(M, N): with T.block("bias_add"): D[vi, vj] = temp[vi, vj] + C[vi, vj] After: for i, j, k in T.grid(M, N, K): with T.block("matmul"): T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) with T.init(): D[vi, vj] = C[vi, vj] # Fused epilogue into init D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] The transformation removes the intermediate temp buffer and the separate epilogue block, enabling better tensorization opportunities for hardware with bias accumulation support. To resolve the issue where multiple epilogue blocks use the same reduction output, we modified the code to handle multiple epilogue blocks cases by adding CheckBufferStillUsed function that checks if other blocks still reference the reduction buffer, and modified to keep the temp buffer if it's still referenced. This ensures that when fusing one epilogue block, other epilogue blocks that still use the intermediate buffer continue to work correctly. Implementation: - ReductionEpilogueFuser class for pattern validation and IR transformation - BodyPatternAllowFusion: Validates epilogue can be fused - AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C) - ExtractEpilogueInfo: Extracts buffer and region information - CreateFusedReductionBlock: Creates single block with modified T.init() - SingleBlockFusionReplacer: Replaces blocks and removes temp buffer - CheckBufferStillUsed: Helper function to check if reduction buffer is still referenced by other blocks after fusion - Conditionally removes temp buffer only if no other blocks reference it - Variable mapping between epilogue and reduction block iter vars - Proper buffer and region updates with correct read/write ordering - FFI bindings and Python API following TVM conventions Changes: - src/tir/schedule/primitive/compute_inline.cc: Core implementation (~430 lines) - src/tir/schedule/primitive.h: Function declaration - include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode - src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode implementation - src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode implementation - src/tir/schedule/schedule.cc: FFI binding registration - python/tvm/tir/schedule/schedule.py: Python API with documentation - tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py: Comprehensive tests including basic fusion, float32 variant, numerical correctness verification, trace roundtrip validation, and multiple epilogue blocks test case Tests can be verified through test_fuse_reduction_epilogue_multiple_epilogue function in tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py. Tests can be run using: python -m pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py --- ...st_tir_schedule_fuse_reduction_epilogue.py | 432 +++++++++--------- 1 file changed, 218 insertions(+), 214 deletions(-) diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py index 37076dd10a31..82a488851ae7 100644 --- a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -1,214 +1,218 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-function-docstring,missing-module-docstring -import pytest -import tvm -import tvm.testing -from tvm import tir -from tvm.script import tir as T -from tvm.tir.schedule.testing import ( - verify_trace_roundtrip, - assert_structural_equal_ignore_global_symbol, -) -import numpy as np - -# pylint: disable=no-member,invalid-name,unused-variable - - -@T.prim_func -def matmul_bias_before( - A: T.Buffer((16, 16), "int8"), - B: T.Buffer((16, 16), "int8"), - C: T.Buffer((16, 16), "int32"), - D: T.Buffer((16, 16), "int32"), -) -> None: - temp = T.alloc_buffer((16, 16), dtype="int32") - for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - temp[vi, vj] = T.int32(0) - temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") - for i, j in T.grid(16, 16): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - D[vi, vj] = temp[vi, vj] + C[vi, vj] - - -@T.prim_func -def matmul_bias_expected( - A: T.Buffer((16, 16), "int8"), - B: T.Buffer((16, 16), "int8"), - C: T.Buffer((16, 16), "int32"), - D: T.Buffer((16, 16), "int32"), -) -> None: - temp = T.alloc_buffer((16, 16), dtype="int32") - for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) - T.writes(D[vi, vj]) - with T.init(): - D[vi, vj] = C[vi, vj] - D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") - - -@T.prim_func -def matmul_bias_fp32_before( - A: T.Buffer((32, 32), "float32"), - B: T.Buffer((32, 32), "float32"), - C: T.Buffer((32, 32), "float32"), - D: T.Buffer((32, 32), "float32"), -) -> None: - temp = T.alloc_buffer((32, 32), dtype="float32") - for i, j, k in T.grid(32, 32, 32): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - temp[vi, vj] = T.float32(0) - temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] - for i, j in T.grid(32, 32): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - D[vi, vj] = temp[vi, vj] + C[vi, vj] - - -@T.prim_func -def matmul_bias_fp32_expected( - A: T.Buffer((32, 32), "float32"), - B: T.Buffer((32, 32), "float32"), - C: T.Buffer((32, 32), "float32"), - D: T.Buffer((32, 32), "float32"), -) -> None: - temp = T.alloc_buffer((32, 32), dtype="float32") - for i, j, k in T.grid(32, 32, 32): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) - T.writes(D[vi, vj]) - with T.init(): - D[vi, vj] = C[vi, vj] - D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] - - -@T.prim_func -def matmul_bias_multiple_epilogue_before( - A: T.Buffer((16, 16), "int8"), - B: T.Buffer((16, 16), "int8"), - C: T.Buffer((16, 16), "int32"), - D: T.Buffer((16, 16), "int32"), - E: T.Buffer((16, 16), "int32"), -) -> None: - temp = T.alloc_buffer((16, 16), dtype="int32") - for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - with T.init(): - temp[vi, vj] = T.int32(0) - temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") - for i, j in T.grid(16, 16): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - D[vi, vj] = temp[vi, vj] + C[vi, vj] - for i, j in T.grid(16, 16): - with T.block("add2"): - vi, vj = T.axis.remap("SS", [i, j]) - E[vi, vj] = temp[vi, vj] + C[vi, vj] - - -@T.prim_func -def matmul_bias_multiple_epilogue_expected( - A: T.Buffer((16, 16), "int8"), - B: T.Buffer((16, 16), "int8"), - C: T.Buffer((16, 16), "int32"), - D: T.Buffer((16, 16), "int32"), - E: T.Buffer((16, 16), "int32"), -) -> None: - temp = T.alloc_buffer((16, 16), dtype="int32") - for i, j, k in T.grid(16, 16, 16): - with T.block("multiply"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) - T.writes(D[vi, vj]) - with T.init(): - D[vi, vj] = C[vi, vj] - D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") - for i, j in T.grid(16, 16): - with T.block("add2"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(temp[vi, vj], C[vi, vj]) - T.writes(E[vi, vj]) - E[vi, vj] = temp[vi, vj] + C[vi, vj] - - -def test_fuse_reduction_epilogue_basic(): - sch = tir.Schedule(matmul_bias_before, debug_mask="all") - sch.fuse_reduction_epilogue("multiply", "add") - assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected) - verify_trace_roundtrip(sch=sch, mod=matmul_bias_before) - - -def test_fuse_reduction_epilogue_fp32(): - sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all") - sch.fuse_reduction_epilogue("multiply", "add") - assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected) - verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before) - - -def test_fuse_reduction_epilogue_numerical_correctness(): - sch_original = tir.Schedule(matmul_bias_before, debug_mask="all") - mod_original = tvm.compile(sch_original.mod["main"], target="llvm") - - sch_fused = tir.Schedule(matmul_bias_before, debug_mask="all") - sch_fused.fuse_reduction_epilogue("multiply", "add") - mod_fused = tvm.compile(sch_fused.mod["main"], target="llvm") - - A_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") - B_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") - C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype="int32") - - expected = (A_np.astype("int32") @ B_np.T.astype("int32")) + C_np - - D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) - D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) - - mod_original(tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), - tvm.runtime.tensor(C_np), D_original_tvm) - - mod_fused(tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), - tvm.runtime.tensor(C_np), D_fused_tvm) - - D_original = D_original_tvm.numpy() - D_fused = D_fused_tvm.numpy() - - np.testing.assert_allclose(D_original, expected, rtol=1e-5) - np.testing.assert_allclose(D_fused, expected, rtol=1e-5) - np.testing.assert_allclose(D_fused, D_original, rtol=1e-5) - - -def test_fuse_reduction_epilogue_multiple_epilogue(): - sch = tir.Schedule(matmul_bias_multiple_epilogue_before, debug_mask="all") - sch.fuse_reduction_epilogue("multiply", "add") - assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_multiple_epilogue_expected) - verify_trace_roundtrip(sch=sch, mod=matmul_bias_multiple_epilogue_before) - - mod = tvm.compile(sch.mod["main"], target="llvm") - assert mod is not None - - -if __name__ == "__main__": - tvm.testing.main() \ No newline at end of file +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import ( + verify_trace_roundtrip, + assert_structural_equal_ignore_global_symbol, +) +import numpy as np + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul_bias_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + + +@T.prim_func +def matmul_bias_fp32_before( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.float32(0) + temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] + for i, j in T.grid(32, 32): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_fp32_expected( + A: T.Buffer((32, 32), "float32"), + B: T.Buffer((32, 32), "float32"), + C: T.Buffer((32, 32), "float32"), + D: T.Buffer((32, 32), "float32"), +) -> None: + temp = T.alloc_buffer((32, 32), dtype="float32") + for i, j, k in T.grid(32, 32, 32): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def matmul_bias_multiple_epilogue_before( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), + E: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + temp[vi, vj] = T.int32(0) + temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = temp[vi, vj] + C[vi, vj] + for i, j in T.grid(16, 16): + with T.block("add2"): + vi, vj = T.axis.remap("SS", [i, j]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +@T.prim_func +def matmul_bias_multiple_epilogue_expected( + A: T.Buffer((16, 16), "int8"), + B: T.Buffer((16, 16), "int8"), + C: T.Buffer((16, 16), "int32"), + D: T.Buffer((16, 16), "int32"), + E: T.Buffer((16, 16), "int32"), +) -> None: + temp = T.alloc_buffer((16, 16), dtype="int32") + for i, j, k in T.grid(16, 16, 16): + with T.block("multiply"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(D[vi, vj]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32") + for i, j in T.grid(16, 16): + with T.block("add2"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(temp[vi, vj], C[vi, vj]) + T.writes(E[vi, vj]) + E[vi, vj] = temp[vi, vj] + C[vi, vj] + + +def test_fuse_reduction_epilogue_basic(): + sch = tir.Schedule(matmul_bias_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_before) + + +def test_fuse_reduction_epilogue_fp32(): + sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before) + + +def test_fuse_reduction_epilogue_numerical_correctness(): + sch_original = tir.Schedule(matmul_bias_before, debug_mask="all") + mod_original = tvm.compile(sch_original.mod["main"], target="llvm") + + sch_fused = tir.Schedule(matmul_bias_before, debug_mask="all") + sch_fused.fuse_reduction_epilogue("multiply", "add") + mod_fused = tvm.compile(sch_fused.mod["main"], target="llvm") + + A_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") + B_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8") + C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype="int32") + + expected = (A_np.astype("int32") @ B_np.T.astype("int32")) + C_np + + D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) + D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32")) + + mod_original( + tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), tvm.runtime.tensor(C_np), D_original_tvm + ) + + mod_fused( + tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), tvm.runtime.tensor(C_np), D_fused_tvm + ) + + D_original = D_original_tvm.numpy() + D_fused = D_fused_tvm.numpy() + + np.testing.assert_allclose(D_original, expected, rtol=1e-5) + np.testing.assert_allclose(D_fused, expected, rtol=1e-5) + np.testing.assert_allclose(D_fused, D_original, rtol=1e-5) + + +def test_fuse_reduction_epilogue_multiple_epilogue(): + sch = tir.Schedule(matmul_bias_multiple_epilogue_before, debug_mask="all") + sch.fuse_reduction_epilogue("multiply", "add") + assert_structural_equal_ignore_global_symbol( + sch.mod["main"], matmul_bias_multiple_epilogue_expected + ) + verify_trace_roundtrip(sch=sch, mod=matmul_bias_multiple_epilogue_before) + + mod = tvm.compile(sch.mod["main"], target="llvm") + assert mod is not None + + +if __name__ == "__main__": + tvm.testing.main() From bd697ccbaabf5a52658d3a8f7a74499ad803e4f3 Mon Sep 17 00:00:00 2001 From: hyun gyu kim Date: Wed, 19 Nov 2025 14:12:39 +0900 Subject: [PATCH 4/4] [TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue into reduction init Currently it is not possible to fuse an epilogue operation (e.g., bias addition) into a reduction block's initialization statement. This limitation prevents leveraging hardware-specific instructions that support bias accumulation in vector ISAs, such as MACC (multiply-accumulate with bias) instructions. This commit implements a new schedule primitive 'fuse_reduction_epilogue' that addresses the problem described in: https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066 The primitive transforms the following pattern: Before: for i, j, k in T.grid(M, N, K): with T.block("matmul"): with T.init(): temp[vi, vj] = 0 temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk] for i, j in T.grid(M, N): with T.block("bias_add"): D[vi, vj] = temp[vi, vj] + C[vi, vj] After: for i, j, k in T.grid(M, N, K): with T.block("matmul"): T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) T.writes(D[vi, vj]) with T.init(): D[vi, vj] = C[vi, vj] # Fused epilogue into init D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] The transformation removes the intermediate temp buffer and the separate epilogue block, enabling better tensorization opportunities for hardware with bias accumulation support. To resolve the issue where multiple epilogue blocks use the same reduction output, we modified the code to handle multiple epilogue blocks cases by adding CheckBufferStillUsed function that checks if other blocks still reference the reduction buffer, and modified to keep the temp buffer if it's still referenced. This ensures that when fusing one epilogue block, other epilogue blocks that still use the intermediate buffer continue to work correctly. Implementation: - ReductionEpilogueFuser class for pattern validation and IR transformation - BodyPatternAllowFusion: Validates epilogue can be fused - AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C) - ExtractEpilogueInfo: Extracts buffer and region information - CreateFusedReductionBlock: Creates single block with modified T.init() - SingleBlockFusionReplacer: Replaces blocks and removes temp buffer - CheckBufferStillUsed: Helper function to check if reduction buffer is still referenced by other blocks after fusion - Conditionally removes temp buffer only if no other blocks reference it - Variable mapping between epilogue and reduction block iter vars - Proper buffer and region updates with correct read/write ordering - FFI bindings and Python API following TVM conventions Changes: - src/tir/schedule/primitive/compute_inline.cc: Core implementation (~430 lines) - src/tir/schedule/primitive.h: Function declaration - include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode - src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode implementation - src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode implementation - src/tir/schedule/schedule.cc: FFI binding registration - python/tvm/tir/schedule/schedule.py: Python API with documentation - tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py: Comprehensive tests including basic fusion, float32 variant, numerical correctness verification, trace roundtrip validation, and multiple epilogue blocks test case Tests can be verified through test_fuse_reduction_epilogue_multiple_epilogue function in tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py. Tests can be run using: python -m pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py --- src/tir/schedule/primitive/compute_inline.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index ced04da37d60..e0be73dcf441 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -1137,10 +1137,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti const BlockRealizeNode* reduction_realize) { ObjectPtr new_block = ffi::make_object(*reduction_block); - // 1. Keep all iter vars (data parallel + reduction) - new_block->iter_vars = reduction_block->iter_vars; - - // 2. Map epilogue block vars to reduction block vars + // 1. Map epilogue block vars to reduction block vars std::vector reduction_data_vars; for (const IterVar& iter_var : reduction_block->iter_vars) { if (iter_var->iter_type == IterVarType::kDataPar) { @@ -1163,12 +1160,12 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti var_map[epilogue_data_vars[i]] = reduction_data_vars[i]; } - // 3. Change init to epilogue value: D[vi, vj] = C[vi, vj] + // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj] BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map), Substitute(epilogue_output_indices_, var_map)); new_block->init = new_init_store; - // 4. Replace output buffer from temp to D in body + // 3. Replace output buffer from temp to D in body class BufferReplacer : public StmtExprMutator { public: BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {} @@ -1197,7 +1194,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_); new_block->body = replacer(reduction_block->body); - // 5. Update write regions + // 4. Update write regions ffi::Array new_writes; for (const BufferRegion& write : reduction_block->writes) { if (write->buffer.same_as(inlined_buffer_)) { @@ -1209,7 +1206,7 @@ Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti } new_block->writes = new_writes; - // 6. Update read regions (C first, then A, B) + // 5. Update read regions (C first, then A, B) ffi::Array new_reads; std::unordered_set read_bufs;