Skip to content

Commit 17af1ee

Browse files
htyumeta-codesync[bot]
authored and committed
[TLX] Interleave TMA stores across MMA groups in Blackwell GEMM epilogue (#1003)
Summary: Pull Request resolved: #1003 Add an interleaved epilogue mode to the Blackwell warp-specialized GEMM kernel. When `INTERLEAVE_EPILOGUE=1`, the epilogue alternates TMA stores between MMA group 0 and group 1 instead of draining each group sequentially. This overlaps the TMA store latency of one group with the TMEM read of the other, improving memory throughput on store-bound shapes. The interleaved path is enabled by default for GPU-saturated shapes and for tall-M shapes with small K (low arithmetic intensity). It is disabled for Split-K configs (which use atomic reductions) and for the tall-M high-arithmetic-intensity path with BLOCK_K=128. Autotuning is also extended to explore `INTERLEAVE_EPILOGUE` in {0, 1}, with config pruning updated to filter invalid combinations (interleave requires `NUM_MMA_GROUPS == 2` and `SPLIT_K == 1`). Perf on B200 (tflops): ``` aten tlx_matmul_ws (M, N, K) matmul before after delta (8192, 8192, 8192) 1142.09 1168.46 1182.49 +1.2% (3159809, 384, 384) 647.23 664.21 644.05 -3.0% (1152, 12800, 32768) 1124.60 1076.00 1069.98 -0.6% (1024, 256, 16384) 363.73 209.39 209.72 +0.2% (560849, 512, 896) 889.47 898.91 938.96 +4.5% (589824, 512, 2048) 915.12 959.46 959.39 -0.0% (1152, 65536, 1024) 1071.84 926.53 962.52 +3.9% (8192, 4608, 6144) 1170.41 1176.01 1195.01 +1.6% (16384, 11264, 5632) 1089.06 1132.88 1149.37 +1.5% (8192, 8192, 2048) 1193.88 1141.82 1162.53 +1.8% average 960.74 935.37 947.40 +1.3% ``` Reviewed By: njriasan Differential Revision: D94608909 fbshipit-source-id: 1d41228ab875a00a7050e1025d838e447833a368
1 parent f73e9a4 commit 17af1ee

File tree

2 files changed

+107
-30
lines changed

2 files changed

+107
-30
lines changed

third_party/tlx/tutorials/blackwell_gemm_ws.py

Lines changed: 106 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
7676
"EPILOGUE_SUBTILE": 1,
7777
"NUM_CTAS": 2,
7878
"SPLIT_K": 1,
79+
"INTERLEAVE_EPILOGUE": 1,
7980
"ctas_per_cga": (2, 1, 1),
8081
"pre_hook": matmul_tma_set_block_size_hook,
8182
}
@@ -95,6 +96,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
9596
"EPILOGUE_SUBTILE": 4,
9697
"NUM_CTAS": 2,
9798
"SPLIT_K": 1,
99+
"INTERLEAVE_EPILOGUE": 0,
98100
"ctas_per_cga": (2, 1, 1),
99101
"pre_hook": matmul_tma_set_block_size_hook,
100102
}
@@ -110,6 +112,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
110112
"EPILOGUE_SUBTILE": 4,
111113
"NUM_CTAS": 2,
112114
"SPLIT_K": 1,
115+
"INTERLEAVE_EPILOGUE": 1,
113116
"ctas_per_cga": (2, 1, 1),
114117
"pre_hook": matmul_tma_set_block_size_hook,
115118
}
@@ -147,6 +150,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
147150
"EPILOGUE_SUBTILE": 8,
148151
"NUM_CTAS": 1,
149152
"SPLIT_K": split_k,
153+
"INTERLEAVE_EPILOGUE": 0,
150154
"ctas_per_cga": None,
151155
"pre_hook": matmul_tma_set_block_size_hook,
152156
}
@@ -163,6 +167,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
163167
"EPILOGUE_SUBTILE": 1,
164168
"NUM_CTAS": 1,
165169
"SPLIT_K": split_k,
170+
"INTERLEAVE_EPILOGUE": 0,
166171
"ctas_per_cga": None,
167172
"pre_hook": matmul_tma_set_block_size_hook,
168173
}
@@ -180,6 +185,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
180185
"EPILOGUE_SUBTILE": 4,
181186
"NUM_CTAS": 1,
182187
"SPLIT_K": 1,
188+
"INTERLEAVE_EPILOGUE": 1,
183189
"ctas_per_cga": None,
184190
"pre_hook": matmul_tma_set_block_size_hook,
185191
}
@@ -311,6 +317,7 @@ def compute_wave_score(bm, bn, num_ctas, split_k=1):
311317
"EPILOGUE_SUBTILE": epilogue_subtile,
312318
"NUM_CTAS": num_ctas,
313319
"SPLIT_K": split_k,
320+
"INTERLEAVE_EPILOGUE": 0,
314321
"ctas_per_cga": (num_ctas, 1, 1) if num_ctas > 1 else None,
315322
"pre_hook": matmul_tma_set_block_size_hook,
316323
}
@@ -359,6 +366,7 @@ def get_cuda_autotune_config():
359366
"EPILOGUE_SUBTILE": subtile,
360367
"NUM_CTAS": num_ctas,
361368
"SPLIT_K": split_k,
369+
"INTERLEAVE_EPILOGUE": interleave,
362370
},
363371
num_warps=4,
364372
num_stages=1,
@@ -374,6 +382,7 @@ def get_cuda_autotune_config():
374382
for subtile in [1, 2, 4, 8]
375383
for num_ctas in [1, 2]
376384
for split_k in [1, 4]
385+
for interleave in [0, 1]
377386
for g in [1, 8, 64]
378387
]
379388

@@ -428,6 +437,7 @@ def preprocess_configs(configs, named_args, **kwargs):
428437
NUM_MMA_GROUPS = conf.kwargs["NUM_MMA_GROUPS"]
429438
SPLIT_K = conf.kwargs.get("SPLIT_K", 1)
430439
EPILOGUE_SUBTILE = conf.kwargs["EPILOGUE_SUBTILE"]
440+
INTERLEAVE_EPILOGUE = conf.kwargs.get("INTERLEAVE_EPILOGUE", 0)
431441

432442
# Filter out invalid config that causes wrong hardware MMA
433443
if BLOCK_M // NUM_MMA_GROUPS > 128:
@@ -437,6 +447,10 @@ def preprocess_configs(configs, named_args, **kwargs):
437447
if BLOCK_N % EPILOGUE_SUBTILE != 0:
438448
continue
439449

450+
# Interleaved epilogue requires NUM_MMA_GROUPS == 2 and SPLIT_K == 1
451+
if INTERLEAVE_EPILOGUE and (NUM_MMA_GROUPS != 2 or SPLIT_K != 1):
452+
continue
453+
440454
num_tiles_m = math.ceil(M / BLOCK_M)
441455
num_tiles_n = math.ceil(N / BLOCK_N)
442456
num_mn_tiles = num_tiles_m * num_tiles_n
@@ -527,6 +541,7 @@ def _group_key(c):
527541
c.kwargs["EPILOGUE_SUBTILE"],
528542
c.kwargs["NUM_CTAS"],
529543
c.kwargs.get("SPLIT_K", 1),
544+
c.kwargs.get("INTERLEAVE_EPILOGUE", 0),
530545
)
531546

532547
def _val(c):
@@ -600,6 +615,7 @@ def _process_tile_epilogue_inner(
600615
NUM_MMA_GROUPS,
601616
NUM_TMEM_BUFFERS,
602617
SPLIT_K,
618+
INTERLEAVE_EPILOGUE,
603619
c_desc,
604620
c_smem_buffers,
605621
tmem_buffers,
@@ -616,38 +632,96 @@ def _process_tile_epilogue_inner(
616632

617633
slice_size: tl.constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
618634

619-
for group_id in tl.static_range(NUM_MMA_GROUPS):
620-
# Wait for TMEM to be filled
621-
buf_idx = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf
622-
623-
tlx.barrier_wait(tmem_full_bars[buf_idx], tmem_read_phase)
624-
625-
# load the result from TMEM to registers
626-
acc_tmem = tmem_buffers[buf_idx]
627-
offs_am = pid_m * BLOCK_SIZE_M + group_id * BLOCK_M_SPLIT
628-
for slice_id in tl.static_range(EPILOGUE_SUBTILE):
629-
acc_tmem_subslice = tlx.local_slice(
630-
acc_tmem,
631-
[0, slice_id * slice_size],
632-
[BLOCK_M_SPLIT, slice_size],
633-
)
634-
result = tlx.local_load(acc_tmem_subslice)
635-
# Signal MMA consumer after each slice
636-
tlx.barrier_arrive(tmem_empty_bars[buf_idx], 1)
635+
if INTERLEAVE_EPILOGUE:
636+
# Interleaved TMA stores across two groups to improve memory throughput.
637+
# Pattern: wait g0, store g0s0, wait g1, store g1s0,
638+
# then alternate g0/g1 for slices 1-3.
639+
buf_idx_0 = 0 * NUM_TMEM_BUFFERS + cur_tmem_buf
640+
buf_idx_1 = 1 * NUM_TMEM_BUFFERS + cur_tmem_buf
641+
acc_tmem_0 = tmem_buffers[buf_idx_0]
642+
acc_tmem_1 = tmem_buffers[buf_idx_1]
643+
offs_am_0 = pid_m * BLOCK_SIZE_M + 0 * BLOCK_M_SPLIT
644+
offs_am_1 = pid_m * BLOCK_SIZE_M + 1 * BLOCK_M_SPLIT
645+
646+
# --- Wait for group 0, store group 0 slice 0 ---
647+
tlx.barrier_wait(tmem_full_bars[buf_idx_0], tmem_read_phase)
648+
acc_sub = tlx.local_slice(acc_tmem_0, [0, 0 * slice_size], [BLOCK_M_SPLIT, slice_size])
649+
result = tlx.local_load(acc_sub)
650+
tlx.barrier_arrive(tmem_empty_bars[buf_idx_0], 1)
651+
c = result.to(tlx.dtype_of(c_desc))
652+
c_smem = c_smem_buffers[0]
653+
tlx.local_store(c_smem, c)
654+
tlx.fence_async_shared()
655+
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_0, offs_bn + 0 * slice_size])
656+
657+
# --- Wait for group 1, store group 1 slice 0 ---
658+
tlx.barrier_wait(tmem_full_bars[buf_idx_1], tmem_read_phase)
659+
acc_sub = tlx.local_slice(acc_tmem_1, [0, 0 * slice_size], [BLOCK_M_SPLIT, slice_size])
660+
result = tlx.local_load(acc_sub)
661+
tlx.barrier_arrive(tmem_empty_bars[buf_idx_1], 1)
662+
c = result.to(tlx.dtype_of(c_desc))
663+
c_smem = c_smem_buffers[1]
664+
tlx.local_store(c_smem, c)
665+
tlx.fence_async_shared()
666+
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_1, offs_bn + 0 * slice_size])
667+
668+
# --- Slices 1-3: alternate group 0, group 1 ---
669+
for slice_id in tl.static_range(1, EPILOGUE_SUBTILE):
670+
# Group 0
671+
acc_sub = tlx.local_slice(acc_tmem_0, [0, slice_id * slice_size], [BLOCK_M_SPLIT, slice_size])
672+
result = tlx.local_load(acc_sub)
673+
tlx.barrier_arrive(tmem_empty_bars[buf_idx_0], 1)
637674
c = result.to(tlx.dtype_of(c_desc))
638-
if SPLIT_K == 1:
639-
# Store to SMEM then use async TMA store to global
640-
c_smem = c_smem_buffers[group_id]
641-
tlx.async_descriptor_store_wait(0)
642-
tlx.local_store(c_smem, c)
643-
tlx.fence_async_shared()
644-
tlx.async_descriptor_store(c_desc, c_smem, [offs_am, offs_bn + slice_id * slice_size])
645-
else:
646-
c_desc.store(
647-
[offs_am, offs_bn + slice_id * slice_size],
648-
c,
649-
store_reduce="add",
675+
c_smem = c_smem_buffers[0]
676+
tlx.async_descriptor_store_wait(1)
677+
tlx.local_store(c_smem, c)
678+
tlx.fence_async_shared()
679+
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_0, offs_bn + slice_id * slice_size])
680+
681+
# Group 1
682+
acc_sub = tlx.local_slice(acc_tmem_1, [0, slice_id * slice_size], [BLOCK_M_SPLIT, slice_size])
683+
result = tlx.local_load(acc_sub)
684+
tlx.barrier_arrive(tmem_empty_bars[buf_idx_1], 1)
685+
c = result.to(tlx.dtype_of(c_desc))
686+
c_smem = c_smem_buffers[1]
687+
tlx.async_descriptor_store_wait(1)
688+
tlx.local_store(c_smem, c)
689+
tlx.fence_async_shared()
690+
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_1, offs_bn + slice_id * slice_size])
691+
else:
692+
for group_id in tl.static_range(NUM_MMA_GROUPS):
693+
# Wait for TMEM to be filled
694+
buf_idx = group_id * NUM_TMEM_BUFFERS + cur_tmem_buf
695+
696+
tlx.barrier_wait(tmem_full_bars[buf_idx], tmem_read_phase)
697+
698+
# load the result from TMEM to registers
699+
acc_tmem = tmem_buffers[buf_idx]
700+
offs_am = pid_m * BLOCK_SIZE_M + group_id * BLOCK_M_SPLIT
701+
for slice_id in tl.static_range(EPILOGUE_SUBTILE):
702+
acc_tmem_subslice = tlx.local_slice(
703+
acc_tmem,
704+
[0, slice_id * slice_size],
705+
[BLOCK_M_SPLIT, slice_size],
650706
)
707+
result = tlx.local_load(acc_tmem_subslice)
708+
# Signal MMA consumer after each slice
709+
tlx.barrier_arrive(tmem_empty_bars[buf_idx], 1)
710+
c = result.to(tlx.dtype_of(c_desc))
711+
if SPLIT_K == 1:
712+
# Store to SMEM then use async TMA store to global
713+
c_smem = c_smem_buffers[group_id]
714+
tlx.async_descriptor_store_wait(0)
715+
tlx.local_store(c_smem, c)
716+
tlx.fence_async_shared()
717+
tlx.async_descriptor_store(c_desc, c_smem, [offs_am, offs_bn + slice_id * slice_size])
718+
else:
719+
c_desc.store(
720+
[offs_am, offs_bn + slice_id * slice_size],
721+
c,
722+
store_reduce="add",
723+
)
724+
651725
# Wait for all TMA stores to complete
652726
tlx.async_descriptor_store_wait(0)
653727

@@ -854,6 +928,7 @@ def matmul_kernel_tma_ws_blackwell(
854928
EPILOGUE_SUBTILE: tl.constexpr,
855929
NUM_CTAS: tl.constexpr,
856930
SPLIT_K: tl.constexpr,
931+
INTERLEAVE_EPILOGUE: tl.constexpr,
857932
NUM_SMS: tl.constexpr,
858933
):
859934
# allocate NUM_SMEM_BUFFERS buffers
@@ -943,6 +1018,7 @@ def matmul_kernel_tma_ws_blackwell(
9431018
NUM_MMA_GROUPS=NUM_MMA_GROUPS,
9441019
NUM_TMEM_BUFFERS=NUM_TMEM_BUFFERS,
9451020
SPLIT_K=SPLIT_K,
1021+
INTERLEAVE_EPILOGUE=INTERLEAVE_EPILOGUE,
9461022
c_desc=c_desc,
9471023
c_smem_buffers=c_smem_buffers,
9481024
tmem_buffers=tmem_buffers,

third_party/tlx/tutorials/testing/test_correctness.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class Gemm:
6363
"EPILOGUE_SUBTILE": 1,
6464
"NUM_CTAS": 1,
6565
"SPLIT_K": 1,
66+
"INTERLEAVE_EPILOGUE": 0,
6667
},
6768
"blackwell_gemm_clc": {
6869
"BLOCK_SIZE_M": 128,

0 commit comments

Comments (0)