
[TLX] Interleave TMA stores across MMA groups in Blackwell GEMM epilogue#1003

Closed
htyu wants to merge 1 commit intofacebookexperimental:mainfrom
htyu:export-D94608909

Conversation

@htyu
Contributor

@htyu htyu commented Feb 27, 2026

Summary:
Add an interleaved epilogue mode to the Blackwell warp-specialized GEMM
kernel. When `INTERLEAVE_EPILOGUE=1`, the epilogue alternates TMA stores
between MMA group 0 and group 1 instead of draining each group
sequentially. This overlaps the TMA store latency of one group with the
TMEM read of the other, improving memory throughput on store-bound shapes.
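
The alternation can be illustrated with a plain-Python schedule sketch (not the actual TLX kernel; the per-group subtile count and the group/subtile naming here are illustrative only):

```python
# Illustrative sketch (not the real kernel): the order in which
# (mma_group, subtile) TMA stores are issued under each epilogue mode.

def sequential_schedule(num_subtiles):
    """Drain all of group 0's TMA stores, then all of group 1's."""
    return [(g, s) for g in (0, 1) for s in range(num_subtiles)]

def interleaved_schedule(num_subtiles):
    """Alternate stores between group 0 and group 1 so one group's
    TMA store latency overlaps the other group's TMEM read."""
    return [(g, s) for s in range(num_subtiles) for g in (0, 1)]

print(sequential_schedule(2))   # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(interleaved_schedule(2))  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```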

The interleaved path is enabled by default for GPU-saturated shapes and
for tall-M shapes with small K (low arithmetic intensity). It is disabled
for Split-K configs (which use atomic reductions) and for the tall-M
high-arithmetic-intensity path with BLOCK_K=128.

Autotuning is also extended to explore `INTERLEAVE_EPILOGUE` in {0, 1},
with config pruning updated to filter invalid combinations (interleave
requires `NUM_MMA_GROUPS == 2` and `SPLIT_K == 1`).
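
A pruning predicate mirroring the stated constraints might look like the following (hypothetical config-dict keys, not the autotuner's real data structures):

```python
# Hypothetical pruning predicate: INTERLEAVE_EPILOGUE=1 is only legal
# with NUM_MMA_GROUPS == 2 and SPLIT_K == 1, per the summary above.

def is_valid_config(cfg):
    if cfg.get("INTERLEAVE_EPILOGUE", 0):
        return cfg.get("NUM_MMA_GROUPS") == 2 and cfg.get("SPLIT_K", 1) == 1
    return True

configs = [
    {"INTERLEAVE_EPILOGUE": 1, "NUM_MMA_GROUPS": 2, "SPLIT_K": 1},  # kept
    {"INTERLEAVE_EPILOGUE": 1, "NUM_MMA_GROUPS": 1, "SPLIT_K": 1},  # pruned
    {"INTERLEAVE_EPILOGUE": 1, "NUM_MMA_GROUPS": 2, "SPLIT_K": 2},  # pruned
    {"INTERLEAVE_EPILOGUE": 0, "NUM_MMA_GROUPS": 1, "SPLIT_K": 2},  # kept
]
pruned = [c for c in configs if is_valid_config(c)]
print(len(pruned))  # 2
```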

Perf on B200 (tflops):

```
                         aten      tlx_matmul_ws
           (M, N, K)  matmul   before    after   delta
  (8192, 8192, 8192) 1142.09  1168.46  1182.49  +1.2%
 (3159809, 384, 384)  647.23   664.21   644.05  -3.0%
(1152, 12800, 32768) 1124.60  1076.00  1069.98  -0.6%
  (1024, 256, 16384)  363.73   209.39   209.72  +0.2%
  (560849, 512, 896)  889.47   898.91   938.96  +4.5%
 (589824, 512, 2048)  915.12   959.46   959.39  -0.0%
 (1152, 65536, 1024) 1071.84   926.53   962.52  +3.9%
  (8192, 4608, 6144) 1170.41  1176.01  1195.01  +1.6%
(16384, 11264, 5632) 1089.06  1132.88  1149.37  +1.5%
  (8192, 8192, 2048) 1193.88  1141.82  1162.53  +1.8%
             average  960.74   935.37   947.40  +1.3%
```

Differential Revision: D94608909

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 27, 2026
@meta-codesync

meta-codesync bot commented Feb 27, 2026

@htyu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D94608909.

@htyu htyu requested a review from dshi7 February 27, 2026 21:21
htyu added a commit to htyu/triton-1 that referenced this pull request Feb 27, 2026
…gue (facebookexperimental#1003)

(commit message identical to the PR summary above)
htyu added a commit to htyu/triton-1 that referenced this pull request Feb 28, 2026
…gue (facebookexperimental#1003)

(commit message identical to the PR summary above)
Contributor

@njriasan njriasan left a comment


Overall this looks good. A couple of minor suggestions, but these can be done as followup or ignored.

```
c,
store_reduce="add",
c_smem = c_smem_buffers[0]
tlx.async_descriptor_store_wait(1)
```
Contributor


Am I correct in saying that hardcoding this to 1 assumes that if we ever wanted larger blocks we would subtile less (e.g. we don't want 2 blocks per MMA subtiled by 4)? If that's something we may consider, then maybe we can parameterize this?

It also seems like with some small refactoring this would work for both 1 and 2 MMAs, since the pattern is simple. What do you think?

Contributor Author


The 1 here only works for two MMA groups: with only one group, TMA stores for different subtiles share the same underlying smem buffer, and overlapping them could be an issue.

With one group we can do multi-buffering between the TMA stores. That's something I'm working on separately.

Contributor


On the one-group case, I was partly suggesting that `tlx.async_descriptor_store_wait(NUM_MMAs - 1)` would make the code general. Not a functionality change, just a code cleanup.
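
A minimal sketch of this suggestion, under the assumption (not confirmed by the TLX docs here) that `async_descriptor_store_wait(n)` means "block until at most n previously issued TMA stores are still in flight":

```python
# Hypothetical model (not the real TLX API): passing NUM_MMA_GROUPS - 1
# as the wait threshold covers both cases with one line of code:
#   one group  -> wait(0): fully drain before reusing the shared buffer;
#   two groups -> wait(1): the other group's store may stay in flight.

def stores_to_drain_before_reuse(num_mma_groups, in_flight):
    """How many outstanding stores must retire before the next smem
    buffer reuse, given `in_flight` currently outstanding stores."""
    threshold = num_mma_groups - 1  # the suggested NUM_MMAs - 1
    return max(0, in_flight - threshold)

print(stores_to_drain_before_reuse(1, 1))  # 1: drain the lone store
print(stores_to_drain_before_reuse(2, 1))  # 0: may proceed immediately
```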

@meta-codesync

meta-codesync bot commented Mar 3, 2026

This pull request has been merged in 17af1ee.

htyu added a commit that referenced this pull request Mar 3, 2026
…gue (#1003)

Summary:
Pull Request resolved: #1003

(commit message identical to the PR summary above)
Reviewed By: njriasan

Differential Revision: D94608909

fbshipit-source-id: 1d41228ab875a00a7050e1025d838e447833a368

Labels

CLA Signed This label is managed by the Meta Open Source bot. fb-exported Merged meta-exported
