Skip to content

Commit 341b78b

Browse files
htyu and meta-codesync[bot]
authored and committed
[TLX] Refactor grouped gemm with configurable subslicing (#651)
Summary: Subslicing enables a bigger tile size and more pipeline stages, which benefits certain shapes. For example, Triton autotuning for function grouped_matmul_tlx_kernel selected this best config: BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, NUM_SMEM_BUFFERS: 3, NUM_TMEM_BUFFERS: 2, EPILOGUE_SUBTILE: **4**, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None. Pull Request resolved: #651. Reviewed By: manman-ren. Differential Revision: D86577923. Pulled By: htyu. fbshipit-source-id: cda92b66d1a727dbc1279792a0c931deead84db9
1 parent c8a4965 commit 341b78b

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

third_party/tlx/tutorials/blackwell-grouped-gemm.py

Lines changed: 11 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,7 @@ def _get_bufidx_phase(accum_cnt, NUM_BUFFERS_KV):
344344
},
345345
num_warps=4,
346346
num_stages=1,
347-
) for BM in [128] for BN in [128, 256] for BK in [64, 128] for s in [2, 3, 4] for t in [2] for subtile in [False]
347+
) for BM in [128] for BN in [128, 256] for BK in [64, 128] for s in [2, 3, 4] for t in [2] for subtile in [1, 2, 4]
348348
]
349349

350350

@@ -411,7 +411,7 @@ def grouped_matmul_tlx_kernel(
411411
c_ptr,
412412
shape=[gm, gn],
413413
strides=[ldc, 1],
414-
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
414+
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N // EPILOGUE_SUBTILE],
415415
)
416416

417417
# iterate through the tiles in the current gemm problem
@@ -430,21 +430,16 @@ def grouped_matmul_tlx_kernel(
430430
offs_cm = tile_m_idx * BLOCK_SIZE_M
431431
offs_cn = tile_n_idx * BLOCK_SIZE_N
432432

433-
if EPILOGUE_SUBTILE:
434-
# We load/store the result half by half to reduce SMEM pressure
435-
acc_tmem_subslice1 = tlx.subslice(acc_tmem, 0, BLOCK_SIZE_N // 2)
436-
result = tlx.local_load(acc_tmem_subslice1)
433+
slice_size: tl.constexpr = BLOCK_SIZE_N // EPILOGUE_SUBTILE
434+
for slice_id in tl.static_range(EPILOGUE_SUBTILE):
435+
acc_slice = tlx.local_slice(
436+
acc_tmem,
437+
[0, slice_id * slice_size],
438+
[BLOCK_SIZE_M, slice_size],
439+
)
440+
result = tlx.local_load(acc_slice)
437441
c = result.to(tl.float16)
438-
c_desc.store([offs_cm, offs_cn], c)
439-
440-
acc_tmem_subslice2 = tlx.subslice(acc_tmem, BLOCK_SIZE_N // 2, BLOCK_SIZE_N // 2)
441-
result = tlx.local_load(acc_tmem_subslice2)
442-
c = result.to(tl.float16)
443-
c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c)
444-
else:
445-
result = tlx.local_load(acc_tmem)
446-
c = result.to(tl.float16)
447-
c_desc.store([offs_cm, offs_cn], c)
442+
c_desc.store([offs_cm, offs_cn + slice_id * slice_size], c)
448443

449444
# done storing this buffer, signal MMA consumer to resume writing to it
450445
tlx.barrier_arrive(tmem_empty_bars[tmem_buf], 1)

0 commit comments

Comments
 (0)