Commit 2032e71
[TIR][Schedule] Add FuseReductionEpilogue primitive to fuse epilogue … (#18418)
Currently it is not possible to fuse an epilogue operation (e.g., bias
addition) into a reduction block's initialization statement. This
limitation prevents leveraging hardware-specific instructions that
support bias accumulation in vector ISAs, such as MACC
(multiply-accumulate with bias) instructions.
This commit implements a new schedule primitive
'fuse_reduction_epilogue' that addresses the problem described in:
https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066
The primitive transforms the following pattern:
Before:
    for i, j, k in T.grid(M, N, K):
        with T.block("matmul"):
            with T.init():
                temp[vi, vj] = 0
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
    for i, j in T.grid(M, N):
        with T.block("bias_add"):
            D[vi, vj] = temp[vi, vj] + C[vi, vj]
After:
    for i, j, k in T.grid(M, N, K):
        with T.block("matmul"):
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(D[vi, vj])
            with T.init():
                D[vi, vj] = C[vi, vj]  # Fused epilogue into init
            D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk]
The transformation removes the intermediate temp buffer and the separate
epilogue block, enabling better tensorization opportunities for hardware
with bias accumulation support.
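The equivalence of the two forms can be checked without TVM. The following is a minimal pure-Python sketch, not code from this commit: the function names `matmul_then_bias` and `matmul_bias_in_init` are hypothetical, and the `k == 0` branch only imitates what `T.init()` does on the first reduction step.

```python
def matmul_then_bias(A, B, C):
    """Unfused form: reduce into temp (initialized to 0), then add the bias in a second pass."""
    M, N, K = len(A), len(B), len(A[0])
    temp = [[0] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                temp[i][j] = temp[i][j] + A[i][k] * B[j][k]
    # Separate "bias_add" epilogue over the intermediate buffer.
    return [[temp[i][j] + C[i][j] for j in range(N)] for i in range(M)]


def matmul_bias_in_init(A, B, C):
    """Fused form: the accumulator D starts at C instead of 0, so no temp buffer is needed."""
    M, N, K = len(A), len(B), len(A[0])
    D = [[None] * N for _ in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                if k == 0:  # imitates T.init(): runs on the first reduction step only
                    D[i][j] = C[i][j]
                D[i][j] = D[i][j] + A[i][k] * B[j][k]
    return D


A = [[1, 2, 3], [4, 5, 6]]                        # M x K
B = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [2, 2, 2]]  # N x K, indexed as B[j][k]
C = [[10, 20, 30, 40], [50, 60, 70, 80]]          # M x N
unfused = matmul_then_bias(A, B, C)
fused = matmul_bias_in_init(A, B, C)
print(fused == unfused)
```

With integer inputs the two results are identical. With floating-point inputs the bias is added first rather than last, so the results can differ in the final bits, and a tolerance-based comparison is the safer check.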
Implementation:
- ReductionEpilogueFuser class for pattern validation and IR
  transformation
  - BodyPatternAllowFusion: Validates epilogue can be fused
  - AnalyzeEpiloguePattern: Detects addition pattern (D = temp + C)
  - ExtractEpilogueInfo: Extracts buffer and region information
  - CreateFusedReductionBlock: Creates single block with modified T.init()
  - SingleBlockFusionReplacer: Replaces blocks and removes temp buffer
- Variable mapping between epilogue and reduction block iter vars
- Proper buffer and region updates with correct read/write ordering
- FFI bindings and Python API following TVM conventions
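To make the pattern-detection and variable-mapping steps concrete, here is a toy Python sketch over a tiny hand-rolled expression tree. It is not the C++ implementation from this commit: the `Load` and `Add` classes and the `analyze_epilogue` and `substitute` helpers are hypothetical, and accepting the addend on either side of the `+` is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Load:
    buffer: str      # buffer name, e.g. "temp" or "C"
    indices: tuple   # iter-var names used as indices


@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Load, Add]


def analyze_epilogue(value, temp_buffer):
    """Return the addend X if value is `temp[...] + X` or `X + temp[...]`, else None."""
    if not isinstance(value, Add):
        return None
    lhs_is_temp = isinstance(value.lhs, Load) and value.lhs.buffer == temp_buffer
    rhs_is_temp = isinstance(value.rhs, Load) and value.rhs.buffer == temp_buffer
    if lhs_is_temp == rhs_is_temp:  # neither side or both sides read temp: not supported here
        return None
    return value.rhs if lhs_is_temp else value.lhs


def substitute(expr, var_map):
    """Rename epilogue iter vars to the reduction block's iter vars."""
    if isinstance(expr, Load):
        return Load(expr.buffer, tuple(var_map.get(v, v) for v in expr.indices))
    return Add(substitute(expr.lhs, var_map), substitute(expr.rhs, var_map))


# Epilogue body: D[ei, ej] = temp[ei, ej] + C[ei, ej]
epilogue_value = Add(Load("temp", ("ei", "ej")), Load("C", ("ei", "ej")))
addend = analyze_epilogue(epilogue_value, "temp")
# Map the epilogue's spatial vars onto the reduction block's vi, vj.
new_init_value = substitute(addend, {"ei": "vi", "ej": "vj"})
print(new_init_value)
```

The resulting `new_init_value` plays the role of the right-hand side of the fused `T.init()` statement, `D[vi, vj] = C[vi, vj]`.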
Changes:
- src/tir/schedule/primitive/compute_inline.cc: Core implementation
  (~430 lines)
- src/tir/schedule/primitive.h: Function declaration
- include/tvm/tir/schedule/schedule.h: Virtual method in ScheduleNode
- src/tir/schedule/concrete_schedule.{h,cc}: ConcreteScheduleNode
  implementation
- src/tir/schedule/traced_schedule.{h,cc}: TracedScheduleNode
  implementation
- src/tir/schedule/schedule.cc: FFI binding registration
- python/tvm/tir/schedule/schedule.py: Python API with documentation
- tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py:
  Comprehensive tests including basic fusion, float32 variant, numerical
  correctness verification, and trace roundtrip validation
Run tests with:
    pytest tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue.py -v
Could you please also take a look at #18240? Thanks :)
---------
Co-authored-by: hyun gyu kim <[email protected]>
1 parent faab2e7 commit 2032e71
File tree
10 files changed, +778 -1 lines changed
- include/tvm/tir/schedule
- python/tvm/tir/schedule
- src/tir/schedule
  - primitive
- tests/python/tir-schedule