Skip to content

Commit 3de877f

Browse files
htyumeta-codesync[bot]
authored and committed
[TLX] Add L2 cache hints to Blackwell GEMM TMA loads and stores (#1027)
Summary: Pull Request resolved: #1027 Add eviction_policy="evict_last" to TMA descriptor loads (A and B) to keep input data in L2 since it's reused across K iterations. Add eviction_policy="evict_first" to TMA descriptor stores (C) to evict output data from L2 since it won't be reread. When SPLIT_K > 1, the store eviction hint is skipped because eviction_policy cannot be combined with store_reduce. Inspired by D94747718. Reviewed By: levendlee Differential Revision: D94928618 fbshipit-source-id: 6c1647cff39d03e85a4413c73cb39f5d6fb89aa4
1 parent ce6394b commit 3de877f

File tree

1 file changed

+50
-24
lines changed

1 file changed

+50
-24
lines changed

third_party/tlx/tutorials/blackwell_gemm_ws.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
150150
"EPILOGUE_SUBTILE": 8,
151151
"NUM_CTAS": 1,
152152
"SPLIT_K": split_k,
153-
"INTERLEAVE_EPILOGUE": 0,
153+
"INTERLEAVE_EPILOGUE": 1,
154154
"ctas_per_cga": None,
155155
"pre_hook": matmul_tma_set_block_size_hook,
156156
}
@@ -167,7 +167,7 @@ def get_heuristic_config(M, N, K, num_sms=148):
167167
"EPILOGUE_SUBTILE": 1,
168168
"NUM_CTAS": 1,
169169
"SPLIT_K": split_k,
170-
"INTERLEAVE_EPILOGUE": 0,
170+
"INTERLEAVE_EPILOGUE": 1,
171171
"ctas_per_cga": None,
172172
"pre_hook": matmul_tma_set_block_size_hook,
173173
}
@@ -447,8 +447,8 @@ def preprocess_configs(configs, named_args, **kwargs):
447447
if BLOCK_N % EPILOGUE_SUBTILE != 0:
448448
continue
449449

450-
# Interleaved epilogue requires NUM_MMA_GROUPS == 2 and SPLIT_K == 1
451-
if INTERLEAVE_EPILOGUE and (NUM_MMA_GROUPS != 2 or SPLIT_K != 1):
450+
# Interleaved epilogue requires NUM_MMA_GROUPS == 2
451+
if INTERLEAVE_EPILOGUE and NUM_MMA_GROUPS != 2:
452452
continue
453453

454454
num_tiles_m = math.ceil(M / BLOCK_M)
@@ -631,6 +631,7 @@ def _process_tile_epilogue_inner(
631631
BLOCK_M_SPLIT: tl.constexpr = BLOCK_SIZE_M // NUM_MMA_GROUPS
632632

633633
slice_size: tl.constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
634+
STORE_REDUCE: tl.constexpr = "add" if SPLIT_K > 1 else ""
634635

635636
if INTERLEAVE_EPILOGUE:
636637
# Interleaved TMA stores across two groups to improve memory throughput.
@@ -652,7 +653,13 @@ def _process_tile_epilogue_inner(
652653
c_smem = c_smem_buffers[0]
653654
tlx.local_store(c_smem, c)
654655
tlx.fence_async_shared()
655-
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_0, offs_bn + 0 * slice_size])
656+
tlx.async_descriptor_store(
657+
c_desc,
658+
c_smem,
659+
[offs_am_0, offs_bn + 0 * slice_size],
660+
store_reduce=STORE_REDUCE,
661+
eviction_policy="evict_first",
662+
)
656663

657664
# --- Wait for group 1, store group 1 slice 0 ---
658665
tlx.barrier_wait(tmem_full_bars[buf_idx_1], tmem_read_phase)
@@ -663,7 +670,13 @@ def _process_tile_epilogue_inner(
663670
c_smem = c_smem_buffers[1]
664671
tlx.local_store(c_smem, c)
665672
tlx.fence_async_shared()
666-
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_1, offs_bn + 0 * slice_size])
673+
tlx.async_descriptor_store(
674+
c_desc,
675+
c_smem,
676+
[offs_am_1, offs_bn + 0 * slice_size],
677+
store_reduce=STORE_REDUCE,
678+
eviction_policy="evict_first",
679+
)
667680

668681
# --- Slices 1-3: alternate group 0, group 1 ---
669682
for slice_id in tl.static_range(1, EPILOGUE_SUBTILE):
@@ -676,7 +689,13 @@ def _process_tile_epilogue_inner(
676689
tlx.async_descriptor_store_wait(1)
677690
tlx.local_store(c_smem, c)
678691
tlx.fence("async_shared")
679-
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_0, offs_bn + slice_id * slice_size])
692+
tlx.async_descriptor_store(
693+
c_desc,
694+
c_smem,
695+
[offs_am_0, offs_bn + slice_id * slice_size],
696+
store_reduce=STORE_REDUCE,
697+
eviction_policy="evict_first",
698+
)
680699

681700
# Group 1
682701
acc_sub = tlx.local_slice(acc_tmem_1, [0, slice_id * slice_size], [BLOCK_M_SPLIT, slice_size])
@@ -687,7 +706,13 @@ def _process_tile_epilogue_inner(
687706
tlx.async_descriptor_store_wait(1)
688707
tlx.local_store(c_smem, c)
689708
tlx.fence("async_shared")
690-
tlx.async_descriptor_store(c_desc, c_smem, [offs_am_1, offs_bn + slice_id * slice_size])
709+
tlx.async_descriptor_store(
710+
c_desc,
711+
c_smem,
712+
[offs_am_1, offs_bn + slice_id * slice_size],
713+
store_reduce=STORE_REDUCE,
714+
eviction_policy="evict_first",
715+
)
691716
else:
692717
for group_id in tl.static_range(NUM_MMA_GROUPS):
693718
# Wait for TMEM to be filled
@@ -708,19 +733,17 @@ def _process_tile_epilogue_inner(
708733
# Signal MMA consumer after each slice
709734
tlx.barrier_arrive(tmem_empty_bars[buf_idx], 1)
710735
c = result.to(tlx.dtype_of(c_desc))
711-
if SPLIT_K == 1:
712-
# Store to SMEM then use async TMA store to global
713-
c_smem = c_smem_buffers[group_id]
714-
tlx.async_descriptor_store_wait(0)
715-
tlx.local_store(c_smem, c)
716-
tlx.fence_async_shared()
717-
tlx.async_descriptor_store(c_desc, c_smem, [offs_am, offs_bn + slice_id * slice_size])
718-
else:
719-
c_desc.store(
720-
[offs_am, offs_bn + slice_id * slice_size],
721-
c,
722-
store_reduce="add",
723-
)
736+
c_smem = c_smem_buffers[group_id]
737+
tlx.async_descriptor_store_wait(0)
738+
tlx.local_store(c_smem, c)
739+
tlx.fence_async_shared()
740+
tlx.async_descriptor_store(
741+
c_desc,
742+
c_smem,
743+
[offs_am, offs_bn + slice_id * slice_size],
744+
store_reduce=STORE_REDUCE,
745+
eviction_policy="evict_first",
746+
)
724747

725748
# Wait for all TMA stores to complete
726749
tlx.async_descriptor_store_wait(0)
@@ -881,13 +904,15 @@ def _process_tile_producer_inner(
881904
tlx.barrier_wait(A_smem_empty_bars[a_buf], phase ^ 1)
882905
offs_am = pid_m * BLOCK_SIZE_M
883906
tlx.barrier_expect_bytes(A_smem_full_bars[a_buf], dsize * BLOCK_M_SPLIT * BLOCK_SIZE_K)
884-
tlx.async_descriptor_load(a_desc, buffers_A[a_buf], [offs_am, offs_k], A_smem_full_bars[a_buf])
907+
tlx.async_descriptor_load(a_desc, buffers_A[a_buf], [offs_am, offs_k], A_smem_full_bars[a_buf],
908+
eviction_policy="evict_last")
885909

886910
# Load B once per K iteration (shared across all subtiles)
887911
last_a_buf = (NUM_MMA_GROUPS - 1) * NUM_SMEM_BUFFERS + buf
888912
tlx.barrier_wait(A_smem_empty_bars[last_a_buf], phase ^ 1)
889913
tlx.barrier_expect_bytes(B_smem_full_bars[buf], expected_bytes)
890-
tlx.async_descriptor_load(b_desc, buffers_B[buf], [offs_k, offs_bn], B_smem_full_bars[buf])
914+
tlx.async_descriptor_load(b_desc, buffers_B[buf], [offs_k, offs_bn], B_smem_full_bars[buf],
915+
eviction_policy="evict_last")
891916

892917
# Load all remaining A subtiles for this K iteration
893918
for group_id in tl.static_range(1, NUM_MMA_GROUPS):
@@ -898,7 +923,8 @@ def _process_tile_producer_inner(
898923
offs_am2 = offs_am + group_id * BLOCK_M_SPLIT
899924

900925
tlx.barrier_expect_bytes(A_smem_full_bars[a_buf], dsize * BLOCK_M_SPLIT * BLOCK_SIZE_K)
901-
tlx.async_descriptor_load(a_desc, buffers_A[a_buf], [offs_am2, offs_k], A_smem_full_bars[a_buf])
926+
tlx.async_descriptor_load(a_desc, buffers_A[a_buf], [offs_am2, offs_k], A_smem_full_bars[a_buf],
927+
eviction_policy="evict_last")
902928

903929
smem_accum_cnt += 1
904930

0 commit comments

Comments
 (0)