Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,13 @@ class ScheduleNode : public runtime::Object {
* \param block The block to be inlined to its producer
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/*!
* \brief Fuse an epilogue block into a reduction block
* \param reduction_block The reduction block (e.g., matmul)
* \param epilogue_block The epilogue block to be fused (e.g., bias add)
*/
virtual void FuseReductionEpilogue(const BlockRV& reduction_block,
const BlockRV& epilogue_block) = 0;
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,6 +2345,33 @@ def after_inline(a: T.handle, c: T.handle) -> None:
# pylint: disable-next=no-member
_ffi_api.ScheduleReverseComputeInline(self, block) # type: ignore

@type_checked
def fuse_reduction_epilogue(
self,
reduction_block: Union[BlockRV, str],
epilogue_block: Union[BlockRV, str],
) -> None:
"""Fuse an epilogue block into a reduction block.

It requires:
1) The reduction block is a complete reduction block
2) The epilogue block only reads from the reduction block's output
3) The epilogue performs a simple addition: output = reduction_result + bias

Parameters
----------
reduction_block : Union[BlockRV, str]
The reduction block (e.g., matmul)
epilogue_block : Union[BlockRV, str]
The epilogue block to be fused (e.g., bias add)
"""
reduction_block = self._normalize_block_arg(reduction_block)
epilogue_block = self._normalize_block_arg(epilogue_block)
# pylint: disable-next=no-member
_ffi_api.ScheduleFuseReductionEpilogue(
self, reduction_block, epilogue_block
) # type: ignore

########## Schedule: Reduction ##########

@type_checked
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,15 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
this->state_->DebugVerify();
}

void ConcreteScheduleNode::FuseReductionEpilogue(const BlockRV& reduction_block_rv,
const BlockRV& epilogue_block_rv) {
TVM_TIR_SCHEDULE_BEGIN();
tir::FuseReductionEpilogue(state_, this->GetSRef(reduction_block_rv),
this->GetSRef(epilogue_block_rv));
TVM_TIR_SCHEDULE_END("fuse-reduction-epilogue", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Block Annotation ********/

void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
int index = -1) override;
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
void FuseReductionEpilogue(const BlockRV& reduction_block,
const BlockRV& epilogue_block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,14 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
* \param block_sref The sref to the block to be inlined to its producer
*/
TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
/*!
* \brief Fuse an epilogue block into a reduction block
* \param self The state of the schedule
* \param reduction_block_sref The sref to the reduction block
* \param epilogue_block_sref The sref to the epilogue block to be fused
*/
TVM_DLL void FuseReductionEpilogue(ScheduleState self, const StmtSRef& reduction_block_sref,
const StmtSRef& epilogue_block_sref);
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
Expand Down
Loading