Skip to content

Commit 2b0fbd0

Browse files
peterbell10makslevental
authored andcommitted
[TUTORIAL] Apply remat perf fix to non-TMA persistent matmul (triton-lang#5811)
Currently the persistent matmul is 10% slower than non-persistent on blackwell, with this fix it's about 20% speedup vs non-persistent.
1 parent 6248526 commit 2b0fbd0

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

python/tutorials/09-persistent-matmul.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,16 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #
258258
if start_pid < num_tiles % NUM_SMS:
259259
tiles_per_SM += 1
260260

261+
# NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being
262+
# used in both the prologue and epilogue, so we duplicate the counters as a work-around.
261263
tile_id = start_pid - NUM_SMS
264+
tile_id_c = start_pid - NUM_SMS
262265
ki = -1
263266

264267
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
265268

266269
num_pid_in_group = GROUP_SIZE_M * num_pid_n
267270

268-
pid_m = 0
269-
pid_n = 0
270271
offs_am = tl.arange(0, BLOCK_SIZE_M)
271272
offs_bn = tl.arange(0, BLOCK_SIZE_N)
272273

@@ -293,6 +294,8 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #
293294
accumulator = tl.dot(a, b, accumulator)
294295

295296
if ki == k_tiles - 1:
297+
tile_id_c, pid_m, pid_n = _compute_tile_and_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M,
298+
NUM_SMS)
296299
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
297300
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
298301
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
@@ -366,8 +369,6 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
366369
tiles_per_SM += 1
367370

368371
tile_id = start_pid - NUM_SMS
369-
# tile_id_c is used in the epilogue to break the dependency between
370-
# the prologue and the epilogue
371372
tile_id_c = start_pid - NUM_SMS
372373

373374
ki = -1

0 commit comments

Comments
 (0)