
Commit 68752b6

htyu authored and facebook-github-bot committed
[TLX] Multi-buffer epilogue TMA stores in Blackwell GEMM
Summary: Use double-buffering for epilogue TMA stores on the non-interleaved path (used by 32 of the 48 shapes in benchmarks). Instead of using a single SMEM buffer per MMA group and waiting for all stores to complete (wait(0)), alternate between two SMEM buffers and wait for all-but-one (wait(1)). The buffer index is computed as (group_id * EPILOGUE_SUBTILE + slice_id) % 2 to avoid collisions across MMA group boundaries. At least 2 SMEM epilogue buffers are now allocated so multi-buffering works even when NUM_MMA_GROUPS == 1. The interleaved epilogue path (used by the remaining 16 of the 48 shapes) already had this optimization via its two-group interleaving pattern.

Also disable heuristic config selection by default, falling back to full autotuning while the heuristics are being stabilized.

On an internal L2 benchmark suite (48 shapes, autotuning):
- Average TFLOPS: 713.1 -> 717.6 (+0.63%)
- Average speedup vs aten: 0.899 -> 0.903
- Biggest wins on bandwidth-bound shapes (small N/K):
  - (1142784, 256, 256): +9.0%
  - (1060571, 512, 512): +8.7%
  - (3159809, 128, 128): +7.6%
  - (589824, 256, 256): +7.0%

Differential Revision: D95074321
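The buffer-index formula above alternates between the two SMEM buffers on every slice, including across MMA group boundaries, because the linear index group_id * EPILOGUE_SUBTILE + slice_id advances by exactly one per slice. A minimal Python sketch (not the TLX kernel itself; the function name is illustrative) checks this property:

```python
def epilogue_buffer_sequence(num_groups: int, epilogue_subtile: int) -> list[int]:
    """SMEM buffer index chosen for each epilogue slice, in issue order.

    Mirrors the commit's formula: (group_id * EPILOGUE_SUBTILE + slice_id) % 2.
    """
    return [
        (group_id * epilogue_subtile + slice_id) % 2
        for group_id in range(num_groups)
        for slice_id in range(epilogue_subtile)
    ]

# The linear index increments by 1 each slice, so consecutive slices
# always land in different buffers -- no collision at group boundaries.
seq = epilogue_buffer_sequence(num_groups=4, epilogue_subtile=2)
print(seq)  # -> [0, 1, 0, 1, 0, 1, 0, 1]
```

This holds for any EPILOGUE_SUBTILE, odd or even, which is why the formula is safe regardless of how many slices each MMA group emits.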
1 parent 638e25e commit 68752b6

File tree

1 file changed: 6 additions, 6 deletions


third_party/tlx/tutorials/blackwell_gemm_ws.py

Lines changed: 6 additions & 6 deletions
@@ -706,11 +706,10 @@ def _process_tile_epilogue_inner(
                 [BLOCK_M_SPLIT, slice_size],
             )
             result = tlx.local_load(acc_tmem_subslice)
-            # Signal MMA consumer after each slice
             tlx.barrier_arrive(tmem_empty_bars[buf_idx], 1)
             c = result.to(tlx.dtype_of(c_desc))
-            c_smem = c_smem_buffers[group_id]
-            tlx.async_descriptor_store_wait(0)
+            c_smem = c_smem_buffers[(group_id * EPILOGUE_SUBTILE + slice_id) % 2]
+            tlx.async_descriptor_store_wait(1)
             tlx.local_store(c_smem, c)
             tlx.fence_async_shared()
             tlx.async_descriptor_store(c_desc, c_smem, [offs_am, offs_bn + slice_id * slice_size], store_reduce=STORE_REDUCE, eviction_policy="evict_first")
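The switch from wait(0) to wait(1) is what enables overlap: the producer no longer drains every in-flight store before reusing a buffer, only all but the most recent one. A small simulation (a hypothetical model, not TLX code; it assumes stores complete in FIFO order) shows why wait(1) is safe with two alternating buffers while a looser wait would overwrite a buffer still being read:

```python
from collections import deque


def simulate(num_iters: int, wait_outstanding: int) -> int:
    """Model the store/wait loop; return count of unsafe buffer overwrites.

    Each iteration: wait until at most `wait_outstanding` async stores are
    pending, write the next double-buffer slot, then issue a store from it.
    """
    in_flight = deque()  # buffer indices with pending async stores (FIFO)
    violations = 0
    for i in range(num_iters):
        buf = i % 2  # two-buffer rotation
        while len(in_flight) > wait_outstanding:
            in_flight.popleft()  # oldest store completes
        if buf in in_flight:
            violations += 1  # overwrote a buffer a store is still reading
        in_flight.append(buf)  # issue async store from buf
    return violations


# wait(1): the single allowed in-flight store is always from the *other*
# buffer, so reuse is safe while one store overlaps with the next write.
print(simulate(num_iters=8, wait_outstanding=1))  # -> 0
```

With wait_outstanding=2 (or more) the model reports violations, matching the intuition that two buffers support exactly one overlapped store.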
@@ -943,12 +942,13 @@ def matmul_kernel_tma_ws_blackwell(
         tlx.storage_kind.tmem,
     )

-    # Allocate SMEM buffer for epilogue TMA store (one per MMA group)
+    # Allocate SMEM buffers for epilogue TMA store (at least 2 for multi-buffering)
+    NUM_EPILOGUE_SMEM_BUFFERS: tl.constexpr = NUM_MMA_GROUPS if NUM_MMA_GROUPS > 2 else 2
     slice_size: tl.constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
     c_smem_buffers = tlx.local_alloc(
         (BLOCK_M_SPLIT, slice_size),
         tlx.dtype_of(c_desc),
-        NUM_MMA_GROUPS,
+        NUM_EPILOGUE_SMEM_BUFFERS,
     )

     # CTA pairs are placed along M dim
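The allocation rule above keeps one buffer per MMA group but never drops below two, so the two-buffer rotation still has a second slot when NUM_MMA_GROUPS == 1 or 2. A plain-Python restatement of the constexpr expression (an illustrative helper, not part of the kernel):

```python
def num_epilogue_smem_buffers(num_mma_groups: int) -> int:
    """At least 2 epilogue SMEM buffers, or one per MMA group if more."""
    return num_mma_groups if num_mma_groups > 2 else 2


# 1 group -> 2 buffers (multi-buffering still possible), 4 groups -> 4 buffers.
print([num_epilogue_smem_buffers(g) for g in (1, 2, 3, 4)])  # -> [2, 2, 3, 4]
```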
@@ -1139,7 +1139,7 @@ def matmul_kernel_tma_ws_blackwell(
         tile_id += NUM_SMS


-def matmul(a, b, config=None, use_heuristic=True):
+def matmul(a, b, config=None, use_heuristic=False):
     """Matrix multiplication using TLX GEMM kernel.

     Args:
