Skip to content

Commit 7f3b62f

Browse files
authored
[gdpa] Add pingpong around tanh
Differential Revision: D81242916. Pull Request resolved: #368.
1 parent: e01dd2d; commit: 7f3b62f

File tree

1 file changed

+72
-26
lines changed

1 file changed

+72
-26
lines changed

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 72 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,12 @@ def get_cuda_autotune_config():
3434
"BLOCK_M": BM,
3535
"BLOCK_N": BN,
3636
"NUM_BUFFERS_Q": bq,
37-
"NUM_BUFFERS_KV": bk,
37+
"NUM_BUFFERS_KV": bkv,
3838
"NUM_BUFFERS_QK": bqk,
3939
"NUM_BUFFERS_O": bo,
40+
"SUBTILING": SUBTILE,
41+
"PINGPONG": pp,
42+
"ACT_REGS": ar,
4043
},
4144
num_warps=4,
4245
num_stages=1,
@@ -45,9 +48,12 @@ def get_cuda_autotune_config():
4548
for BM in [256] # 128 or 256
4649
for BN in [128]
4750
for bq in [1]
48-
for bk in [3]
51+
for bkv in [3]
4952
for bqk in [1] # in tmem
5053
for bo in [1] # in tmem
54+
for SUBTILE in [True] # doesn't support False
55+
for pp in [True, False]
56+
for ar in [192, 232]
5157
]
5258

5359

@@ -285,6 +291,9 @@ def gdpa_kernel_tma_ws_blackwell(
285291
NUM_BUFFERS_KV: tl.constexpr,
286292
NUM_BUFFERS_QK: tl.constexpr,
287293
NUM_BUFFERS_O: tl.constexpr,
294+
SUBTILING: tl.constexpr,
295+
PINGPONG: tl.constexpr,
296+
ACT_REGS: tl.constexpr,
288297
):
289298
n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
290299
prog_id = tl.program_id(0)
@@ -375,7 +384,7 @@ def gdpa_kernel_tma_ws_blackwell(
375384

376385
with tlx.async_tasks():
377386
# activation calculation
378-
with tlx.async_task("default", registers=192):
387+
with tlx.async_task("default", registers=ACT_REGS):
379388
accum_cnt = 0
380389
accum_cnt_outer = 0
381390
for _ in range(0, tiles_per_sm):
@@ -411,23 +420,38 @@ def gdpa_kernel_tma_ws_blackwell(
411420
# qk_view: BLOCK_M // 2, HEAD_DIM
412421
qk_view_1st = tlx.subslice(qk_view, 0, HEAD_DIM // 2)
413422
qk0 = tlx.local_load(qk_view_1st)
414-
p0 = fast_gelu(qk0)
415-
p0 = p0.to(dtype)
416-
p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
417-
tlx.local_store(p0_view, p0)
418-
419423
qk_view_2nd = tlx.subslice(
420424
qk_view, HEAD_DIM // 2, HEAD_DIM // 2
421425
)
422-
qk0 = tlx.local_load(qk_view_2nd)
423-
p0 = fast_gelu(qk0)
426+
qk1 = tlx.local_load(qk_view_2nd)
427+
c1 = 0.0356774081
428+
c0 = 0.7978845608
429+
square = _mul_f32x2(qk0, qk0)
430+
inner = _fma_f32x2(c1, square, c0)
431+
inner0 = _mul_f32x2(inner, qk0)
432+
square = _mul_f32x2(qk1, qk1)
433+
inner = _fma_f32x2(c1, square, c0)
434+
inner1 = _mul_f32x2(inner, qk1)
435+
436+
if PINGPONG:
437+
tlx.named_barrier_wait(9, 128)
438+
# p0 = fast_gelu(qk0)
439+
p0 = _fma_f32x2(qk0, tanh_approx_fp32(inner0), qk0)
424440
p0 = p0.to(dtype)
425-
p0_view = tlx.local_reinterpret(qk_view_2nd, dtype)
441+
p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
426442
tlx.local_store(p0_view, p0)
427443

444+
# p1 = fast_gelu(qk1)
445+
p1 = _fma_f32x2(qk1, tanh_approx_fp32(inner1), qk1)
446+
p1 = p1.to(dtype)
447+
p1_view = tlx.local_reinterpret(qk_view_2nd, dtype)
448+
tlx.local_store(p1_view, p1)
449+
428450
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
429451
consumer_release_qk_view = tlx.local_view(producer_qk0, bufIdx)
430452
tlx.barrier_arrive(consumer_release_qk_view, 1)
453+
if PINGPONG:
454+
tlx.named_barrier_arrive(10, 128)
431455

432456
# wait for o0, o1 per iteration
433457
bufIdx = accum_cnt % NUM_BUFFERS_O
@@ -436,10 +460,12 @@ def gdpa_kernel_tma_ws_blackwell(
436460
consumer_o0_view = tlx.local_view(producer_commit_o0, bufIdx)
437461
# tl.device_print("default producer_commit_o0", accum_cnt)
438462
# tl.device_print("default producer_commit_o0_phase", phase)
439-
tlx.barrier_wait(consumer_o0_view, phase)
463+
# there is no need to wait for o0 at each iteration
464+
# tlx.barrier_wait(consumer_o0_view, phase)
440465
accum_cnt += 1
441466

442467
# epilogue here, load from tmem
468+
# FIXME: wait till o0 is done for the inner loop
443469
bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
444470
accum_cnt_outer, NUM_BUFFERS_O
445471
)
@@ -472,9 +498,11 @@ def gdpa_kernel_tma_ws_blackwell(
472498
accum_cnt_outer += 1
473499
tile_idx += num_progs
474500

475-
with tlx.async_task(num_warps=4, registers=192):
501+
with tlx.async_task(num_warps=4, registers=ACT_REGS):
476502
accum_cnt = 0
477503
accum_cnt_outer = 0
504+
if PINGPONG:
505+
tlx.named_barrier_arrive(9, 128)
478506
for _ in range(0, tiles_per_sm):
479507
pid = tile_idx % n_tile_num
480508
start_m = pid
@@ -505,32 +533,49 @@ def gdpa_kernel_tma_ws_blackwell(
505533
# qk_view: BLOCK_M // 2, HEAD_DIM
506534
qk_view_1st = tlx.subslice(qk_view, 0, HEAD_DIM // 2)
507535
qk0 = tlx.local_load(qk_view_1st)
508-
p0 = fast_gelu(qk0)
509-
p0 = p0.to(dtype)
510-
p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
511-
tlx.local_store(p0_view, p0)
512-
513536
qk_view_2nd = tlx.subslice(
514537
qk_view, HEAD_DIM // 2, HEAD_DIM // 2
515538
)
516-
qk0 = tlx.local_load(qk_view_2nd)
517-
p0 = fast_gelu(qk0)
539+
qk1 = tlx.local_load(qk_view_2nd)
540+
c1 = 0.0356774081
541+
c0 = 0.7978845608
542+
square = _mul_f32x2(qk0, qk0)
543+
inner = _fma_f32x2(c1, square, c0)
544+
inner0 = _mul_f32x2(inner, qk0)
545+
square = _mul_f32x2(qk1, qk1)
546+
inner = _fma_f32x2(c1, square, c0)
547+
inner1 = _mul_f32x2(inner, qk1)
548+
549+
if PINGPONG:
550+
tlx.named_barrier_wait(10, 128)
551+
# p0 = fast_gelu(qk0)
552+
p0 = _fma_f32x2(qk0, tanh_approx_fp32(inner0), qk0)
518553
p0 = p0.to(dtype)
519-
p0_view = tlx.local_reinterpret(qk_view_2nd, dtype)
554+
p0_view = tlx.local_reinterpret(qk_view_1st, dtype)
520555
tlx.local_store(p0_view, p0)
521556

557+
# p1 = fast_gelu(qk1)
558+
p1 = _fma_f32x2(qk1, tanh_approx_fp32(inner1), qk1)
559+
p1 = p1.to(dtype)
560+
p1_view = tlx.local_reinterpret(qk_view_2nd, dtype)
561+
tlx.local_store(p1_view, p1)
562+
522563
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
523564
consumer_release_qk_view = tlx.local_view(producer_qk1, bufIdx)
524565
tlx.barrier_arrive(consumer_release_qk_view, 1)
566+
if PINGPONG:
567+
tlx.named_barrier_arrive(9, 128)
525568

526569
# wait for o0, o1 per iteration
527570
bufIdx = accum_cnt % NUM_BUFFERS_O
528571
phase = (accum_cnt // NUM_BUFFERS_O) & 1
529572
# consumer wait of o1
530573
consumer_o1_view = tlx.local_view(producer_commit_o1, bufIdx)
531-
tlx.barrier_wait(consumer_o1_view, phase)
574+
# there is no need to wait for o1 at each iteration
575+
# tlx.barrier_wait(consumer_o1_view, phase)
532576
accum_cnt += 1
533577
# epilogue here, load from tmem
578+
# FIXME: wait till o1 is done for the inner loop
534579
bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
535580
accum_cnt_outer, NUM_BUFFERS_O
536581
)
@@ -1210,15 +1255,16 @@ def gdpa_forward_tlx(
12101255

12111256
stage = 1 # When supporting causal, change to 3
12121257
extra_kern_args = {}
1258+
# extra_kern_args["maxnreg"] = 168
12131259
nheads = query.shape[1]
12141260
G = query.shape[1] // key.shape[1]
12151261
assert query.shape[1] % key.shape[1] == 0
12161262
batch_size = BATCH * nheads
12171263
NUM_SMS = (
12181264
get_num_sms() or 1000000
1219-
) * 8 # if num sms is None, use a large number so that it is a no-op
1220-
print("NUM_SMS", NUM_SMS)
1221-
print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads)
1265+
) # * 8 # if num sms is None, use a large number so that it is a no-op
1266+
# print("NUM_SMS", NUM_SMS)
1267+
# print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads)
12221268

12231269
q = expect_contiguous(query)
12241270
k = expect_contiguous(key)
@@ -1268,7 +1314,7 @@ def grid_tma_persistent(META):
12681314
)
12691315

12701316
activation_enum_int = activation_string_to_int(activation)
1271-
print(q.shape, k.shape, v.shape)
1317+
# print(q.shape, k.shape, v.shape)
12721318
# print("activation_enum_int", activation, activation_enum_int)
12731319
# print(query_offset)
12741320
# print(key_offset)

0 commit comments

Comments
 (0)