[TLX] Interleave TMA stores across MMA groups in Blackwell GEMM epilogue #1003
htyu wants to merge 1 commit into facebookexperimental:main
…gue (facebookexperimental#1003)

Summary:
Add an interleaved epilogue mode to the Blackwell warp-specialized GEMM kernel. When `INTERLEAVE_EPILOGUE=1`, the epilogue alternates TMA stores between MMA group 0 and group 1 instead of draining each group sequentially. This overlaps the TMA store latency of one group with the TMEM read of the other, improving memory throughput on store-bound shapes.

The interleaved path is enabled by default for GPU-saturated shapes and for tall-M shapes with small K (low arithmetic intensity). It is disabled for Split-K configs (which use atomic reductions) and for the tall-M high-arithmetic-intensity path with BLOCK_K=128.

Autotuning is also extended to explore `INTERLEAVE_EPILOGUE` in {0, 1}, with config pruning updated to filter invalid combinations (interleave requires `NUM_MMA_GROUPS == 2` and `SPLIT_K == 1`).

Perf on B200 (TFLOPS):
```
(M, N, K)             aten matmul  tlx_matmul_ws before  tlx_matmul_ws after  delta
(8192, 8192, 8192)        1142.09               1168.46              1182.49  +1.2%
(3159809, 384, 384)        647.23                664.21               644.05  -3.0%
(1152, 12800, 32768)      1124.60               1076.00              1069.98  -0.6%
(1024, 256, 16384)         363.73                209.39               209.72  +0.2%
(560849, 512, 896)         889.47                898.91               938.96  +4.5%
(589824, 512, 2048)        915.12                959.46               959.39  -0.0%
(1152, 65536, 1024)       1071.84                926.53               962.52  +3.9%
(8192, 4608, 6144)        1170.41               1176.01              1195.01  +1.6%
(16384, 11264, 5632)      1089.06               1132.88              1149.37  +1.5%
(8192, 8192, 2048)        1193.88               1141.82              1162.53  +1.8%
average                    960.74                935.37               947.40  +1.3%
```

Differential Revision: D94608909
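To make the sequential and interleaved store orders concrete, here is a minimal pure-Python model, not the TLX kernel code. `epilogue_store_order` and `min_reuse_distance` are hypothetical helpers; the model assumes each MMA group stages its subtiles through a single SMEM buffer of its own, so the gap between two stores from the same group is the window in which a store can overlap other work, such as the other group's TMEM read.

```python
from typing import Dict, List, Optional, Tuple

Store = Tuple[int, int]  # (mma_group, subtile); hypothetical model, not TLX types


def epilogue_store_order(num_groups: int, num_subtiles: int, interleave: bool) -> List[Store]:
    """Order in which the epilogue issues TMA stores under this model."""
    if interleave:
        # Alternate groups: g0.s0, g1.s0, g0.s1, g1.s1, ...
        return [(g, s) for s in range(num_subtiles) for g in range(num_groups)]
    # Sequential: drain every subtile of group 0, then group 1, ...
    return [(g, s) for g in range(num_groups) for s in range(num_subtiles)]


def min_reuse_distance(order: List[Store]) -> Optional[int]:
    """Smallest number of issue slots between two stores from the same group.

    With one SMEM buffer per group (an assumption of this model), a store must
    finish before the next store from that group refills the buffer, so a larger
    distance leaves more room to overlap it with other work.
    Returns None if no group issues more than one store.
    """
    last_seen: Dict[int, int] = {}
    best: Optional[int] = None
    for i, (group, _) in enumerate(order):
        if group in last_seen:
            gap = i - last_seen[group]
            best = gap if best is None else min(best, gap)
        last_seen[group] = i
    return best
```

Under this model, two groups with four subtiles give a reuse distance of 1 in the sequential order and 2 in the interleaved order, so one store can stay in flight while the other group's subtile is prepared.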
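The autotuning change (explore `INTERLEAVE_EPILOGUE` in {0, 1}, then drop combinations that do not satisfy `NUM_MMA_GROUPS == 2` and `SPLIT_K == 1`) can be sketched as a plain filter over candidate configs. This is a minimal illustration with dicts, not the kernel's actual autotune code; `expand_interleave`, `is_valid`, and `prune_interleave_configs` are hypothetical names, and only the two stated constraints are encoded. In Triton, a filter like this would typically be wired in through `triton.autotune`'s `prune_configs_by` hook, operating on `triton.Config` objects rather than dicts.

```python
from typing import Dict, Iterable, List

Config = Dict[str, int]  # hypothetical: meta-parameter name -> value


def expand_interleave(configs: Iterable[Config]) -> List[Config]:
    """Try INTERLEAVE_EPILOGUE in {0, 1} for every base config."""
    return [{**cfg, "INTERLEAVE_EPILOGUE": flag} for cfg in configs for flag in (0, 1)]


def is_valid(cfg: Config) -> bool:
    """Interleaving needs exactly two MMA groups and no Split-K."""
    if cfg.get("INTERLEAVE_EPILOGUE", 0) == 1:
        return cfg["NUM_MMA_GROUPS"] == 2 and cfg["SPLIT_K"] == 1
    return True  # this rule only constrains the interleaved path


def prune_interleave_configs(configs: Iterable[Config]) -> List[Config]:
    """Drop invalid (config, INTERLEAVE_EPILOGUE) combinations."""
    return [cfg for cfg in configs if is_valid(cfg)]
```

For example, starting from three base configs (two groups without Split-K, one group without Split-K, and two groups with `SPLIT_K = 4`), the expansion yields six candidates and the filter keeps four; only the first base config is tried with interleaving enabled.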
Branch updated from 3873895 to 270d56a.
njriasan left a comment:
Overall this looks good. A couple of minor suggestions, but these can be done as a follow-up or ignored.
```
c,
store_reduce="add",
c_smem = c_smem_buffers[0]
tlx.async_descriptor_store_wait(1)
```
Am I correct in saying that hardcoding this to 1 assumes that, if we ever wanted larger blocks, we would subtile less (e.g., we don't want 2 blocks per MMA subtiled by 4)? If that's something we might consider, then maybe we can parameterize this.
It also seems like with some small refactoring this would work for both 1 and 2 MMAs since the pattern is simple. What do you think?
The 1 here only works for two MMA groups: with only one group, TMA stores for different subtiles share the same underlying SMEM buffer, so overlapping them could be an issue.
With one group, we can do multi-buffering between the TMA stores. That's something I'm working on separately.
For the one-group case, I was partly suggesting that `tlx.async_descriptor_store_wait(NUM_MMAs - 1)` would make the code general, since with one group it reduces to a wait of 0 (no overlap). I wasn't suggesting overlapping the stores there, just a code cleanup.
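The `NUM_MMAs - 1` generalization can be checked with a small model. Its assumptions: one SMEM staging buffer per MMA group, stores issued round-robin across groups as in the interleaved epilogue, and `store_wait(n)` blocking until at most `n` previously issued TMA stores are still reading SMEM (this is an assumed reading of `tlx.async_descriptor_store_wait`, not taken from its documentation). `simulate_epilogue` is a hypothetical helper that reports whether any buffer would be overwritten while a store still reads it.

```python
from collections import deque
from typing import Deque


def simulate_epilogue(num_groups: int, num_subtiles: int, wait_count: int) -> bool:
    """Return True if no SMEM buffer is refilled while a store still reads it.

    Hypothetical model: buffer g belongs to MMA group g, and stores are issued
    round-robin across groups for each subtile (interleaved epilogue).
    """
    pending: Deque[int] = deque()  # buffer ids of in-flight stores, oldest first
    for _ in range(num_subtiles):
        for g in range(num_groups):
            # Model of store_wait(wait_count): retire the oldest stores
            # until at most wait_count remain in flight.
            while len(pending) > wait_count:
                pending.popleft()
            if g in pending:
                return False  # buffer g is still being read by an earlier store
            # (write this group's subtile into buffer g, then issue its store)
            pending.append(g)
    return True
```

Under this model, a wait count of `num_groups - 1` is the largest safe value: it is the current `1` for two groups and `0` (fully drain before reuse) for one group, while `num_groups` lets a store overwrite a buffer that is still in flight.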
This pull request has been merged in 17af1ee.
…gue (#1003)

Pull Request resolved: #1003
Reviewed By: njriasan
Differential Revision: D94608909
fbshipit-source-id: 1d41228ab875a00a7050e1025d838e447833a368