diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 60deae801f87..a768a7dd4f31 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -608,6 +608,13 @@ class ScheduleNode : public runtime::Object {
    * \param block The block to be inlined to its producer
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
+  /*!
+   * \brief Fuse an epilogue block into a reduction block
+   * \param reduction_block The reduction block (e.g., matmul)
+   * \param epilogue_block The epilogue block to be fused (e.g., bias add)
+   */
+  virtual void FuseReductionEpilogue(const BlockRV& reduction_block,
+                                     const BlockRV& epilogue_block) = 0;
   /******** Schedule: Reduction ********/
   /*!
    * \brief Decompose a reduction block into two separate blocks.
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index ffa7e7174f28..92d082274682 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2345,6 +2345,33 @@ def after_inline(a: T.handle, c: T.handle) -> None:
         # pylint: disable-next=no-member
         _ffi_api.ScheduleReverseComputeInline(self, block)  # type: ignore
 
+    @type_checked
+    def fuse_reduction_epilogue(
+        self,
+        reduction_block: Union[BlockRV, str],
+        epilogue_block: Union[BlockRV, str],
+    ) -> None:
+        """Fuse an epilogue block into a reduction block.
+
+        It requires:
+        1) The reduction block is a complete reduction block
+        2) The epilogue block only reads from the reduction block's output
+        3) The epilogue performs a simple addition: output = reduction_result + bias
+
+        Parameters
+        ----------
+        reduction_block : Union[BlockRV, str]
+            The reduction block (e.g., matmul)
+        epilogue_block : Union[BlockRV, str]
+            The epilogue block to be fused (e.g., bias add)
+        """
+        reduction_block = self._normalize_block_arg(reduction_block)
+        epilogue_block = self._normalize_block_arg(epilogue_block)
+        # pylint: disable-next=no-member
+        _ffi_api.ScheduleFuseReductionEpilogue(
+            self, reduction_block, epilogue_block
+        )  # type: ignore
+
     ########## Schedule: Reduction ##########
 
     @type_checked
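
To make the new user-facing API concrete, here is a minimal usage sketch (not part of the patch). It assumes a small matmul-plus-bias workload written in TVMScript, with the reduction block named "multiply" and the bias block named "add", mirroring the test cases added at the end of this patch.

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def matmul_bias(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.float32(0)
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj]


sch = tir.Schedule(matmul_bias)
sch.fuse_reduction_epilogue("multiply", "add")
# After fusion, the init of "multiply" should become D[vi, vj] = C[vi, vj]
print(sch.mod.script())
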
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 89ece537713d..00f421e733e2 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -832,6 +832,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
   this->state_->DebugVerify();
 }
 
+void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv,
+                                                 const BlockRV& epilogue_block_rv) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv),
+                             this->GetSRef(epilogue_block_rv));
+  TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Block Annotation ********/
 
 void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index b6f87a3aae8f..7ee54961415b 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
                         int index = -1) override;
   void ComputeInline(const BlockRV& block) override;
   void ReverseComputeInline(const BlockRV& block) override;
+  void FuseReductionEpilogue(const BlockRV& reduction_block,
+                             const BlockRV& epilogue_block) override;
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
   BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 0c3e5a0efd21..1af0033791f4 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -509,6 +509,14 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
  * \param block_sref The sref to the block to be inlined to its producer
  */
 TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
+/*!
+ * \brief Fuse an epilogue block into a reduction block
+ * \param self The state of the schedule
+ * \param reduction_block_sref The sref to the reduction block
+ * \param epilogue_block_sref The sref to the epilogue block to be fused
+ */
+TVM_DLL void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref,
+                                   const StmtSRef& epilogue_block_sref);
 /******** Schedule: Reduction ********/
 /*!
  * \brief Decompose a reduction block into two separate blocks.
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index e480c68ff4ad..e0be73dcf441 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -984,6 +984,469 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
   ReverseComputeInlineImpl(self, consumer_block_sref);
 }
 
+/*!
+ * \brief Helper to fuse epilogue block into reduction block
+ * Analyzes epilogue pattern and transforms reduction init/update
+ */
+class ReductionEpilogueFuser : public BaseInliner {
+ public:
+  explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block,
+                                  const BlockRealize& epilogue_block_realize,
+                                  const StmtSRef& scope_root_sref, const IRModule& mod)
+      : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref),
+        reduction_block_(reduction_block),
+        epilogue_block_(epilogue_block_realize->block.get()),
+        mod_(mod) {}
+
+  bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);
+
+  // Step 2: Create single fused reduction block
+  Block CreateFusedReductionBlock(const BlockNode* reduction_block,
+                                  const BlockRealizeNode* reduction_realize);
+
+ private:
+  bool AnalyzeEpiloguePattern(const PrimExpr& value);
+  bool IsReductionBlock(const BlockNode* block);
+  void ExtractEpilogueInfo();
+  // Helper function to extract BufferLoad nodes from BufferStore
+  static std::vector<const BufferLoadNode*> ExtractBufferLoad(const Buffer& buffer,
+                                                              const BufferStoreNode* from) {
+    struct Extractor : public ExprVisitor {
+      void VisitExpr_(const BufferLoadNode* load) final {
+        if (load->buffer.get() == buffer) {
+          result.push_back(load);
+        }
+        ExprVisitor::VisitExpr_(load);
+      }
+      const BufferNode* buffer;
+      std::vector<const BufferLoadNode*> result;
+    } extractor;
+    extractor.buffer = buffer.get();
+    for (const PrimExpr& expr : from->indices) {
+      extractor(expr);
+    }
+    extractor(from->value);
+    return std::move(extractor.result);
+  }
+
+  const BlockNode* reduction_block_;
+  const BlockNode* epilogue_block_;
+  const IRModule& mod_;
+  PrimExpr epilogue_addend_{nullptr};                       // C[vi, vj] in D = temp + C
+  Buffer epilogue_output_buffer_{nullptr};                  // Output buffer D
+  ffi::Array<PrimExpr> epilogue_output_indices_{nullptr};   // Indices of D[vi, vj]
+  BufferRegion epilogue_output_region_{nullptr};            // Write region of D
+  Buffer epilogue_addend_buffer_{nullptr};                  // Addend buffer C
+  BufferRegion epilogue_addend_region_{nullptr};            // Read region of C
+};
+
+bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) {
+  // 1. Validate predicate
+  if (!is_one(epilogue_block_realize->predicate)) {
+    // Failure: Predicate in epilogue block is not supported
+    return false;
+  }
+
+  // 2. Check if epilogue body is BufferStore
+  if (inlined_store_ == nullptr) {
+    // Failure: epilogue block body is not BufferStore
+    return false;
+  }
+
+  // 3. Check if epilogue reads from reduction buffer
+  std::vector<const BufferLoadNode*> loads = ExtractBufferLoad(inlined_buffer_, inlined_store_);
+  if (loads.size() == 0) {
+    // Failure: no BufferLoad from the reduction buffer
+    return false;
+  }
+
+  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+  if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
+    // Failure: epilogue is not a simple addition pattern
+    return false;
+  }
+
+  // 5. Check if producer is a reduction block
+  if (!IsReductionBlock(reduction_block_)) {
+    // Failure: producer is not a reduction block
+    return false;
+  }
+
+  // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+  ExtractEpilogueInfo();
+
+  return true;
+}
+
+bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
+  // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+  if (const auto* add = value.as<AddNode>()) {
+    const auto* load_a = add->a.as<BufferLoadNode>();
+    const auto* load_b = add->b.as<BufferLoadNode>();
+
+    bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+    bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+    // Ensure exactly one operand is from the reduction buffer
+    if (a_is_target != b_is_target) {
+      epilogue_addend_ = a_is_target ? add->b : add->a;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) {
+  // Check if block has reduction iter vars
+  for (const IterVar& iter : block->iter_vars) {
+    if (iter->iter_type == kCommReduce) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void ReductionEpilogueFuser::ExtractEpilogueInfo() {
+  // Extract epilogue output buffer and indices
+  epilogue_output_buffer_ = inlined_store_->buffer;
+  epilogue_output_indices_ = inlined_store_->indices;
+
+  // Extract epilogue output region from epilogue block writes
+  for (const BufferRegion& write : epilogue_block_->writes) {
+    if (write->buffer.same_as(epilogue_output_buffer_)) {
+      epilogue_output_region_ = write;
+      break;
+    }
+  }
+
+  // Extract epilogue addend buffer and region from epilogue_addend_
+  if (const auto* load = epilogue_addend_.as<BufferLoadNode>()) {
+    epilogue_addend_buffer_ = load->buffer;
+    // Find the read region from epilogue block reads
+    for (const BufferRegion& read : epilogue_block_->reads) {
+      if (read->buffer.same_as(epilogue_addend_buffer_)) {
+        epilogue_addend_region_ = read;
+        break;
+      }
+    }
+  }
+}
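
AnalyzeEpiloguePattern only accepts a single flat addition in which exactly one operand is a load from the reduction buffer. As a hedged illustration (not part of the patch; the function and block names here are invented for the example), an epilogue that does anything else, such as a multiply, should be rejected and surface in Python as a ScheduleError:

import pytest
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def matmul_scale(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.float32(0)
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
    for i, j in T.grid(16, 16):
        with T.block("scale"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] * C[vi, vj]  # multiplication, not the supported addition


sch = tir.Schedule(matmul_scale)
with pytest.raises(tvm.tir.ScheduleError):
    sch.fuse_reduction_epilogue("multiply", "scale")
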
+
+Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block,
+                                                        const BlockRealizeNode* reduction_realize) {
+  ObjectPtr<BlockNode> new_block = ffi::make_object<BlockNode>(*reduction_block);
+
+  // 1. Map epilogue block vars to reduction block vars
+  std::vector<Var> reduction_data_vars;
+  for (const IterVar& iter_var : reduction_block->iter_vars) {
+    if (iter_var->iter_type == IterVarType::kDataPar) {
+      reduction_data_vars.push_back(iter_var->var);
+    }
+  }
+  std::vector<Var> epilogue_data_vars;
+  for (const IterVar& iter_var : epilogue_block_->iter_vars) {
+    if (iter_var->iter_type == IterVarType::kDataPar) {
+      epilogue_data_vars.push_back(iter_var->var);
+    }
+  }
+
+  ICHECK_EQ(reduction_data_vars.size(), epilogue_data_vars.size())
+      << "ValueError: The number of data parallel iter vars must be the same in the reduction "
+         "and epilogue blocks.";
+
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map;
+  for (size_t i = 0; i < reduction_data_vars.size(); ++i) {
+    var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
+  }
+
+  // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
+  BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
+                             Substitute(epilogue_output_indices_, var_map));
+  new_block->init = new_init_store;
+
+  // 3. Replace output buffer from temp to D in body
+  class BufferReplacer : public StmtExprMutator {
+   public:
+    BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {}
+
+    Stmt VisitStmt_(const BufferStoreNode* op) final {
+      BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+      if (store->buffer.same_as(old_buffer_)) {
+        return BufferStore(new_buffer_, store->value, store->indices);
+      }
+      return store;
+    }
+
+    PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+      BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+      if (load->buffer.same_as(old_buffer_)) {
+        return BufferLoad(new_buffer_, load->indices);
+      }
+      return load;
+    }
+
+   private:
+    Buffer old_buffer_;
+    Buffer new_buffer_;
+  };
+
+  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
+  new_block->body = replacer(reduction_block->body);
+
+  // 4. Update write regions
+  ffi::Array<BufferRegion> new_writes;
+  for (const BufferRegion& write : reduction_block->writes) {
+    if (write->buffer.same_as(inlined_buffer_)) {
+      new_writes.push_back(
+          BufferRegion(epilogue_output_buffer_, Substitute(write->region, var_map)));
+    } else {
+      new_writes.push_back(write);
+    }
+  }
+  new_block->writes = new_writes;
+
+  // 5. Update read regions (C first, then A, B)
+  ffi::Array<BufferRegion> new_reads;
+  std::unordered_set<const BufferNode*> read_bufs;
+
+  // Add C buffer read first (used in init)
+  if (epilogue_addend_buffer_.defined()) {
+    new_reads.push_back(BufferRegion(epilogue_addend_buffer_,
+                                     Substitute(epilogue_addend_region_->region, var_map)));
+    read_bufs.insert(epilogue_addend_buffer_.get());
+  }
+
+  // Add existing read regions (A, B, etc.)
+  for (const BufferRegion& read : reduction_block->reads) {
+    if (!read->buffer.same_as(inlined_buffer_)) {
+      // Only add non-temp buffers
+      if (read_bufs.find(read->buffer.get()) == read_bufs.end()) {
+        new_reads.push_back(read);
+        read_bufs.insert(read->buffer.get());
+      }
+    }
+  }
+
+  new_block->reads = new_reads;
+
+  return Block(new_block);
+}
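
The core rewrite in CreateFusedReductionBlock moves the epilogue addend into the reduction init (D[vi, vj] = C[vi, vj]) and accumulates straight into D. For a sum reduction this rests on the identity bias + sum_k(a_k * b_k) == (acc = bias; then acc += a_k * b_k for each k), which a quick NumPy check illustrates (illustrative only, not part of the patch):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((16, 16)).astype("float32")
B = rng.standard_normal((16, 16)).astype("float32")
C = rng.standard_normal((16, 16)).astype("float32")

# Epilogue applied after the reduction: D = (A @ B.T) + C
after = A @ B.T + C

# Fused form: start the accumulator at the bias, then accumulate the products
fused = C.copy()
for k in range(16):
    fused += np.outer(A[:, k], B[:, k])

np.testing.assert_allclose(fused, after, rtol=1e-5)
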
+
+/*!
+ * \brief Check if a buffer is still referenced by other blocks in the scope
+ */
+static bool CheckBufferStillUsed(const Block& scope_root, const Buffer& buffer) {
+  class BufferUsageChecker : public StmtVisitor {
+   public:
+    explicit BufferUsageChecker(const Buffer& buffer) : buffer_(buffer) {}
+
+    bool CheckStmt(const Stmt& stmt) {
+      found_usage_ = false;
+      VisitStmt(stmt);
+      return found_usage_;
+    }
+
+   private:
+    void VisitStmt_(const BlockRealizeNode* op) final {
+      if (found_usage_) return;
+
+      if (!op || !op->block.defined()) {
+        StmtVisitor::VisitStmt_(op);
+        return;
+      }
+
+      const BlockNode* block = op->block.get();
+      if (!block) {
+        StmtVisitor::VisitStmt_(op);
+        return;
+      }
+
+      // Check reads
+      for (const BufferRegion& read : block->reads) {
+        if (read->buffer.same_as(buffer_)) {
+          found_usage_ = true;
+          return;
+        }
+      }
+
+      // Check writes
+      for (const BufferRegion& write : block->writes) {
+        if (write->buffer.same_as(buffer_)) {
+          found_usage_ = true;
+          return;
+        }
+      }
+
+      // Continue visiting nested blocks
+      StmtVisitor::VisitStmt_(op);
+    }
+
+    void VisitStmt_(const BlockNode* op) final {
+      if (found_usage_) return;
+      if (!op) return;
+
+      // Check alloc_buffers
+      for (const Buffer& buf : op->alloc_buffers) {
+        if (buf.same_as(buffer_)) {
+          found_usage_ = true;
+          return;
+        }
+      }
+
+      StmtVisitor::VisitStmt_(op);
+    }
+
+    const Buffer& buffer_;
+    bool found_usage_{false};
+  };
+
+  if (!scope_root->body.defined()) {
+    return false;
+  }
+
+  BufferUsageChecker checker(buffer);
+  return checker.CheckStmt(scope_root->body);
+}
+
+/*!
+ * \brief Helper class to replace reduction and epilogue blocks with a single fused block
+ */
+class SingleBlockFusionReplacer : public StmtMutator {
+ public:
+  static Block Replace(Block old_scope_root, Block new_fused_block, Block old_reduction_block,
+                       Block old_epilogue_block, Buffer reduction_buffer) {
+    SingleBlockFusionReplacer replacer(std::move(new_fused_block), std::move(old_reduction_block),
+                                       std::move(old_epilogue_block), std::move(reduction_buffer));
+    Block result = Downcast<Block>(replacer(std::move(old_scope_root)));
+
+    // Check if reduction_buffer is still referenced by other blocks
+    bool buffer_still_used = CheckBufferStillUsed(result, reduction_buffer);
+
+    // Remove intermediate temp buffer only if it's not used by other blocks
+    if (!buffer_still_used) {
+      BlockNode* p = result.CopyOnWrite();
+      ffi::Array<Buffer> new_alloc_buffers;
+      for (const Buffer& buf : p->alloc_buffers) {
+        if (!buf.same_as(reduction_buffer)) {
+          new_alloc_buffers.push_back(buf);
+        }
+      }
+      p->alloc_buffers = new_alloc_buffers;
+    }
+
+    return result;
+  }
+
+ private:
+  explicit SingleBlockFusionReplacer(Block new_fused_block, Block old_reduction_block,
+                                     Block old_epilogue_block, Buffer reduction_buffer)
+      : new_fused_block_(std::move(new_fused_block)),
+        old_reduction_block_(std::move(old_reduction_block)),
+        old_epilogue_block_(std::move(old_epilogue_block)),
+        reduction_buffer_(std::move(reduction_buffer)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_body = StmtMutator::VisitStmt(loop->body);
+    // Remove empty loops (containing only Evaluate(0))
+    if (mutated_body.as<EvaluateNode>()) {
+      return mutated_body;  // Return Evaluate(0) to be removed by SeqStmt
+    }
+
+    return For(loop->loop_var, loop->min, loop->extent, loop->kind, mutated_body,
+               loop->thread_binding, loop->annotations);
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    if (realize->block.same_as(old_reduction_block_)) {
+      // Replace reduction block with new fused block
+      ObjectPtr<BlockRealizeNode> new_realize = ffi::make_object<BlockRealizeNode>(*realize);
+      new_realize->block = new_fused_block_;
+      return BlockRealize(new_realize);
+    } else if (realize->block.same_as(old_epilogue_block_)) {
+      // Remove epilogue block completely
+      return Evaluate(0);
+    }
+    return StmtMutator::VisitStmt_(realize);
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    ffi::Array<Stmt> new_stmts;
+    for (const Stmt& stmt : seq->seq) {
+      Stmt new_stmt = VisitStmt(stmt);
+      // Remove Evaluate(0)
+      if (!new_stmt.as<EvaluateNode>()) {
+        new_stmts.push_back(new_stmt);
+      }
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  Block new_fused_block_;
+  Block old_reduction_block_;
+  Block old_epilogue_block_;
+  Buffer reduction_buffer_;
+};
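
Because SingleBlockFusionReplacer only drops the temp allocation when CheckBufferStillUsed finds no remaining reader or writer, fusing one of several epilogues keeps the intermediate buffer alive. A rough sketch of that behaviour (not part of the patch; it reuses matmul_bias_multiple_epilogue_before, which is defined in the new test file at the end of this patch, and assumes the TVMScript printer keeps the original buffer and block names):

from tvm import tir

sch = tir.Schedule(matmul_bias_multiple_epilogue_before)
sch.fuse_reduction_epilogue("multiply", "add")
script = sch.mod.script()
assert "temp" in script              # still allocated: block "add2" reads it
assert 'T.block("add2")' in script   # the unfused epilogue is left in place
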
+
+void FuseReductionEpilogueImpl(ScheduleState self, const StmtSRef& reduction_block_sref,
+                               const StmtSRef& epilogue_block_sref, bool check_only = false) {
+  const BlockNode* _reduction_block = TVM_SREF_TO_BLOCK(reduction_block_sref);
+  const BlockNode* _epilogue_block = TVM_SREF_TO_BLOCK(epilogue_block_sref);
+
+  Block reduction_block = ffi::GetRef<Block>(_reduction_block);
+  Block epilogue_block = ffi::GetRef<Block>(_epilogue_block);
+  BlockRealize epilogue_block_realize = GetBlockRealize(self, epilogue_block_sref);
+
+  // Step 1. Get the scope block
+  StmtSRef scope_root_sref =
+      GetScopeRoot(self, epilogue_block_sref, /*require_stage_pipeline=*/true);
+
+  // Step 2. Get the reduction buffer (intermediate buffer)
+  Buffer reduction_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, reduction_block);
+
+  // Step 3. Check completeness and reduction block properties
+  CheckReductionBlock(self, reduction_block_sref, scope_root_sref);
+  CheckCompleteBlock(self, epilogue_block_sref, scope_root_sref);
+  CheckNotOutputBlock(self, reduction_block_sref, scope_root_sref);
+
+  // Step 4. Analyze the epilogue pattern
+  ReductionEpilogueFuser fuser(reduction_buffer, _reduction_block, epilogue_block_realize,
+                               scope_root_sref, self->mod);
+  if (!fuser.BodyPatternAllowFusion(epilogue_block_realize)) {
+    throw BodyAnalysisError(true, self->mod, epilogue_block);
+  }
+
+  if (check_only) {
+    return;
+  }
+
+  // Step 5. Create single fused reduction block
+  BlockRealize reduction_realize = GetBlockRealize(self, reduction_block_sref);
+  Block fused_block = fuser.CreateFusedReductionBlock(_reduction_block, reduction_realize.get());
+
+  // Step 6. Transform and replace IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
+
+  Block new_scope_root =
+      SingleBlockFusionReplacer::Replace(ffi::GetRef<Block>(old_scope_root), fused_block,
+                                         reduction_block, epilogue_block, reduction_buffer);
+
+  // Step 7. Update schedule state
+  ffi::Map<Block, Block> block_reuse;
+  block_reuse.Set(ffi::GetRef<Block>(old_scope_root), new_scope_root);
+  block_reuse.Set(reduction_block, fused_block);
+  self->Replace(scope_root_sref, new_scope_root, block_reuse);
+
+  // Step 8. Update BlockInfo
+  self->UpdateScopeBlockInfo(GetBlockRealize(self, scope_root_sref));
+}
+
+void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref,
+                           const StmtSRef& epilogue_block_sref) {
+  FuseReductionEpilogueImpl(self, reduction_block_sref, epilogue_block_sref);
+}
+
 /******** InstructionKind Registration ********/
 
 struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
@@ -1035,5 +1498,34 @@ struct ReverseComputeInlineTraits : public UnpackedInstTraits<ReverseComputeInlineTraits> {
 
 TVM_REGISTER_INST_KIND_TRAITS(ReverseComputeInlineTraits);
 
+struct FuseReductionEpilogueTraits : public UnpackedInstTraits<FuseReductionEpilogueTraits> {
+  static constexpr const char* kName = "FuseReductionEpilogue";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 2;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static void UnpackedApplyToSchedule(Schedule sch, BlockRV reduction_block_rv,
+                                      BlockRV epilogue_block_rv) {
+    return sch->FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv);
+  }
+
+  static ffi::String UnpackedAsPython(ffi::Array<ffi::String> outputs,
+                                      ffi::String reduction_block_rv,
+                                      ffi::String epilogue_block_rv) {
+    PythonAPICall py("fuse_reduction_epilogue");
+    py.Input("reduction_block", reduction_block_rv);
+    py.Input("epilogue_block", epilogue_block_rv);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
+TVM_REGISTER_INST_KIND_TRAITS(FuseReductionEpilogueTraits);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 845bbb5cc278..35b221561978 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -227,7 +227,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_method("tir.schedule.ScheduleComputeAt", &ScheduleNode::ComputeAt)
       .def_method("tir.schedule.ScheduleReverseComputeAt", &ScheduleNode::ReverseComputeAt)
       .def_method("tir.schedule.ScheduleComputeInline", &ScheduleNode::ComputeInline)
-      .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline);
+      .def_method("tir.schedule.ScheduleReverseComputeInline", &ScheduleNode::ReverseComputeInline)
+      .def_method("tir.schedule.ScheduleFuseReductionEpilogue",
+                  &ScheduleNode::FuseReductionEpilogue);
 }
 /******** (FFI) Reduction ********/
 TVM_FFI_STATIC_INIT_BLOCK() {
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 8129f43833c4..72606f243d69 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -532,6 +532,17 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
                                       /*outputs=*/{}));
 }
 
+void TracedScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv,
+                                               const BlockRV& epilogue_block_rv) {
+  ConcreteScheduleNode::FuseReductionEpilogue(reduction_block_rv, epilogue_block_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("FuseReductionEpilogue");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{reduction_block_rv, epilogue_block_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{}));
+}
+
 /******** Schedule: Reduction ********/
 
 BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index 0b91dc283392..8c7b16a47e8d 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -109,6 +109,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
                         int index = -1) final;
   void ComputeInline(const BlockRV& block_rv) final;
   void ReverseComputeInline(const BlockRV& block_rv) final;
+  void FuseReductionEpilogue(const BlockRV& reduction_block, const BlockRV& epilogue_block) final;
   /******** Schedule: Reduction ********/
   BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final;
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
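
Since FuseReductionEpilogue is registered as an InstructionKind and recorded by TracedScheduleNode, applying it leaves an entry in the schedule trace; that is what lets verify_trace_roundtrip in the tests below replay the transformation. A small sketch (not part of the patch), reusing the matmul_bias workload from the first example above:

from tvm import tir

sch = tir.Schedule(matmul_bias)  # schedules are traced by default
sch.fuse_reduction_epilogue("multiply", "add")
print(sch.trace)
# Expected to include something like:
#   b0 = sch.get_block(name="multiply", func_name="main")
#   b1 = sch.get_block(name="add", func_name="main")
#   sch.fuse_reduction_epilogue(reduction_block=b0, epilogue_block=b1)
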
diff --git a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
new file mode 100644
index 000000000000..82a488851ae7
--- /dev/null
+++ b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@T.prim_func
+def matmul_bias_before(
+    A: T.Buffer((16, 16), "int8"),
+    B: T.Buffer((16, 16), "int8"),
+    C: T.Buffer((16, 16), "int32"),
+    D: T.Buffer((16, 16), "int32"),
+) -> None:
+    temp = T.alloc_buffer((16, 16), dtype="int32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.int32(0)
+            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
+    for i, j in T.grid(16, 16):
+        with T.block("add"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+@T.prim_func
+def matmul_bias_expected(
+    A: T.Buffer((16, 16), "int8"),
+    B: T.Buffer((16, 16), "int8"),
+    C: T.Buffer((16, 16), "int32"),
+    D: T.Buffer((16, 16), "int32"),
+) -> None:
+    temp = T.alloc_buffer((16, 16), dtype="int32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = C[vi, vj]
+            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
+
+
+@T.prim_func
+def matmul_bias_fp32_before(
+    A: T.Buffer((32, 32), "float32"),
+    B: T.Buffer((32, 32), "float32"),
+    C: T.Buffer((32, 32), "float32"),
+    D: T.Buffer((32, 32), "float32"),
+) -> None:
+    temp = T.alloc_buffer((32, 32), dtype="float32")
+    for i, j, k in T.grid(32, 32, 32):
+        with T.block("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+    for i, j in T.grid(32, 32):
+        with T.block("add"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+@T.prim_func
+def matmul_bias_fp32_expected(
+    A: T.Buffer((32, 32), "float32"),
+    B: T.Buffer((32, 32), "float32"),
+    C: T.Buffer((32, 32), "float32"),
+    D: T.Buffer((32, 32), "float32"),
+) -> None:
+    temp = T.alloc_buffer((32, 32), dtype="float32")
+    for i, j, k in T.grid(32, 32, 32):
+        with T.block("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = C[vi, vj]
+            D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def matmul_bias_multiple_epilogue_before(
+    A: T.Buffer((16, 16), "int8"),
+    B: T.Buffer((16, 16), "int8"),
+    C: T.Buffer((16, 16), "int32"),
+    D: T.Buffer((16, 16), "int32"),
+    E: T.Buffer((16, 16), "int32"),
+) -> None:
+    temp = T.alloc_buffer((16, 16), dtype="int32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.int32(0)
+            temp[vi, vj] = temp[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
+    for i, j in T.grid(16, 16):
+        with T.block("add"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = temp[vi, vj] + C[vi, vj]
+    for i, j in T.grid(16, 16):
+        with T.block("add2"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+@T.prim_func
+def matmul_bias_multiple_epilogue_expected(
+    A: T.Buffer((16, 16), "int8"),
+    B: T.Buffer((16, 16), "int8"),
+    C: T.Buffer((16, 16), "int32"),
+    D: T.Buffer((16, 16), "int32"),
+    E: T.Buffer((16, 16), "int32"),
+) -> None:
+    temp = T.alloc_buffer((16, 16), dtype="int32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("multiply"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = C[vi, vj]
+            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
+    for i, j in T.grid(16, 16):
+        with T.block("add2"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(temp[vi, vj], C[vi, vj])
+            T.writes(E[vi, vj])
+            E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
+def test_fuse_reduction_epilogue_basic():
+    sch = tir.Schedule(matmul_bias_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("multiply", "add")
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_expected)
+    verify_trace_roundtrip(sch=sch, mod=matmul_bias_before)
+
+
+def test_fuse_reduction_epilogue_fp32():
+    sch = tir.Schedule(matmul_bias_fp32_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("multiply", "add")
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], matmul_bias_fp32_expected)
+    verify_trace_roundtrip(sch=sch, mod=matmul_bias_fp32_before)
+
+
+def test_fuse_reduction_epilogue_numerical_correctness():
+    sch_original = tir.Schedule(matmul_bias_before, debug_mask="all")
+    mod_original = tvm.compile(sch_original.mod["main"], target="llvm")
+
+    sch_fused = tir.Schedule(matmul_bias_before, debug_mask="all")
+    sch_fused.fuse_reduction_epilogue("multiply", "add")
+    mod_fused = tvm.compile(sch_fused.mod["main"], target="llvm")
+
+    A_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8")
+    B_np = np.random.randint(-128, 127, size=(16, 16), dtype="int8")
+    C_np = np.random.randint(-1000, 1000, size=(16, 16), dtype="int32")
+
+    expected = (A_np.astype("int32") @ B_np.T.astype("int32")) + C_np
+
+    D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32"))
+    D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="int32"))
+
+    mod_original(
+        tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), tvm.runtime.tensor(C_np), D_original_tvm
+    )
+
+    mod_fused(
+        tvm.runtime.tensor(A_np), tvm.runtime.tensor(B_np), tvm.runtime.tensor(C_np), D_fused_tvm
+    )
+
+    D_original = D_original_tvm.numpy()
+    D_fused = D_fused_tvm.numpy()
+
+    np.testing.assert_allclose(D_original, expected, rtol=1e-5)
+    np.testing.assert_allclose(D_fused, expected, rtol=1e-5)
+    np.testing.assert_allclose(D_fused, D_original, rtol=1e-5)
+
+
+def test_fuse_reduction_epilogue_multiple_epilogue():
+    sch = tir.Schedule(matmul_bias_multiple_epilogue_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("multiply", "add")
+    assert_structural_equal_ignore_global_symbol(
+        sch.mod["main"], matmul_bias_multiple_epilogue_expected
+    )
+    verify_trace_roundtrip(sch=sch, mod=matmul_bias_multiple_epilogue_before)
+
+    mod = tvm.compile(sch.mod["main"], target="llvm")
+    assert mod is not None
+
+
+if __name__ == "__main__":
+    tvm.testing.main()