Commit f02dedb
hyun gyu kim
Currently it is not possible to fuse an epilogue operation (e.g., bias addition)
into a reduction block's initialization statement. This limitation prevents
leveraging hardware-specific instructions that support bias accumulation in
vector ISAs, such as MACC (multiply-accumulate with bias) instructions.
This commit implements a new schedule primitive 'fuse_reduction_epilogue' that
addresses the problem described in:
https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066
The primitive transforms the following pattern:
Before:
for i, j, k in T.grid(M, N, K):
with T.block("matmul"):
with T.init():
temp[vi, vj] = 0
temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
for i, j in T.grid(M, N):
with T.block("bias_add"):
D[vi, vj] = temp[vi, vj] + C[vi, vj]
After:
for i, j, k in T.grid(M, N, K):
with T.block("matmul"):
T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
T.writes(D[vi, vj])
with T.init():
D[vi, vj] = C[vi, vj] # Fused epilogue into init
D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk]
The transformation removes the intermediate temp buffer and the separate
epilogue block, enabling better tensorization opportunities for hardware
with bias accumulation support.
To resolve the issue where multiple epilogue blocks use the same reduction
output, we modified the code to handle multiple epilogue blocks cases by
adding CheckBufferStillUsed function that checks if other blocks still
reference the reduction buffer, and modified to keep the temp buffer if it's
still referenced. This ensures that when fusing one epilogue block, other
epilogue blocks that still use the intermediate buffer continue to work
correctly.
Implementation:
- ReductionEpilogueFuser class for pattern validation and IR transformation
- BodyPatternAllowFusion: Validates epilogue can be fused
- AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C)
- ExtractEpilogueInfo: Extracts buffer and region information
- CreateFusedReductionBlock: Creates single block with modified T.init()
- SingleBlockFusionReplacer: Replaces blocks and removes temp buffer
- CheckBufferStillUsed: Helper function to check if reduction buffer is
still referenced by other blocks after fusion
- Conditionally removes temp buffer only if no other blocks reference it
- Variable mapping between epilogue and reduction block iter vars
- Proper buffer and region updates with correct read/write ordering
- FFI bindings and Python API following TVM conventions
Changes:
- src/tir/schedule/primitive/compute_inline.cc: Core implementation (~430 lines)
- src/tir/schedule/primitive.h: Function declaration
- include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode
- src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode implementation
- src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode implementation
- src/tir/schedule/schedule.cc: FFI binding registration
- python/tvm/tir/schedule/schedule.py: Python API with documentation
- tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py:
Comprehensive tests including basic fusion, float32 variant, numerical
correctness verification, trace roundtrip validation, and multiple epilogue
blocks test case
Tests can be verified through test_fuse_reduction_epilogue_multiple_epilogue
function in tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py.
Tests can be run using:
python -m pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py1 parent 71ee6b1 commit f02dedb
File tree
1 file changed
+218
-214
lines changed- tests/python/tir-schedule
1 file changed
+218
-214
lines changed
0 commit comments