@@ -34,9 +34,12 @@ def get_cuda_autotune_config():
34
34
"BLOCK_M" : BM ,
35
35
"BLOCK_N" : BN ,
36
36
"NUM_BUFFERS_Q" : bq ,
37
- "NUM_BUFFERS_KV" : bk ,
37
+ "NUM_BUFFERS_KV" : bkv ,
38
38
"NUM_BUFFERS_QK" : bqk ,
39
39
"NUM_BUFFERS_O" : bo ,
40
+ "SUBTILING" : SUBTILE ,
41
+ "PINGPONG" : pp ,
42
+ "ACT_REGS" : ar ,
40
43
},
41
44
num_warps = 4 ,
42
45
num_stages = 1 ,
@@ -45,9 +48,12 @@ def get_cuda_autotune_config():
45
48
for BM in [256 ] # 128 or 256
46
49
for BN in [128 ]
47
50
for bq in [1 ]
48
- for bk in [3 ]
51
+ for bkv in [3 ]
49
52
for bqk in [1 ] # in tmem
50
53
for bo in [1 ] # in tmem
54
+ for SUBTILE in [True ] # doesn't support False
55
+ for pp in [True , False ]
56
+ for ar in [192 , 232 ]
51
57
]
52
58
53
59
@@ -285,6 +291,9 @@ def gdpa_kernel_tma_ws_blackwell(
285
291
NUM_BUFFERS_KV : tl .constexpr ,
286
292
NUM_BUFFERS_QK : tl .constexpr ,
287
293
NUM_BUFFERS_O : tl .constexpr ,
294
+ SUBTILING : tl .constexpr ,
295
+ PINGPONG : tl .constexpr ,
296
+ ACT_REGS : tl .constexpr ,
288
297
):
289
298
n_tile_num = tl .cdiv (N_CTX , BLOCK_M )
290
299
prog_id = tl .program_id (0 )
@@ -375,7 +384,7 @@ def gdpa_kernel_tma_ws_blackwell(
375
384
376
385
with tlx .async_tasks ():
377
386
# activation calculation
378
- with tlx .async_task ("default" , registers = 192 ):
387
+ with tlx .async_task ("default" , registers = ACT_REGS ):
379
388
accum_cnt = 0
380
389
accum_cnt_outer = 0
381
390
for _ in range (0 , tiles_per_sm ):
@@ -411,23 +420,38 @@ def gdpa_kernel_tma_ws_blackwell(
411
420
# qk_view: BLOCK_M // 2, HEAD_DIM
412
421
qk_view_1st = tlx .subslice (qk_view , 0 , HEAD_DIM // 2 )
413
422
qk0 = tlx .local_load (qk_view_1st )
414
- p0 = fast_gelu (qk0 )
415
- p0 = p0 .to (dtype )
416
- p0_view = tlx .local_reinterpret (qk_view_1st , dtype )
417
- tlx .local_store (p0_view , p0 )
418
-
419
423
qk_view_2nd = tlx .subslice (
420
424
qk_view , HEAD_DIM // 2 , HEAD_DIM // 2
421
425
)
422
- qk0 = tlx .local_load (qk_view_2nd )
423
- p0 = fast_gelu (qk0 )
426
+ qk1 = tlx .local_load (qk_view_2nd )
427
+ c1 = 0.0356774081
428
+ c0 = 0.7978845608
429
+ square = _mul_f32x2 (qk0 , qk0 )
430
+ inner = _fma_f32x2 (c1 , square , c0 )
431
+ inner0 = _mul_f32x2 (inner , qk0 )
432
+ square = _mul_f32x2 (qk1 , qk1 )
433
+ inner = _fma_f32x2 (c1 , square , c0 )
434
+ inner1 = _mul_f32x2 (inner , qk1 )
435
+
436
+ if PINGPONG :
437
+ tlx .named_barrier_wait (9 , 128 )
438
+ # p0 = fast_gelu(qk0)
439
+ p0 = _fma_f32x2 (qk0 , tanh_approx_fp32 (inner0 ), qk0 )
424
440
p0 = p0 .to (dtype )
425
- p0_view = tlx .local_reinterpret (qk_view_2nd , dtype )
441
+ p0_view = tlx .local_reinterpret (qk_view_1st , dtype )
426
442
tlx .local_store (p0_view , p0 )
427
443
444
+ # p1 = fast_gelu(qk1)
445
+ p1 = _fma_f32x2 (qk1 , tanh_approx_fp32 (inner1 ), qk1 )
446
+ p1 = p1 .to (dtype )
447
+ p1_view = tlx .local_reinterpret (qk_view_2nd , dtype )
448
+ tlx .local_store (p1_view , p1 )
449
+
428
450
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
429
451
consumer_release_qk_view = tlx .local_view (producer_qk0 , bufIdx )
430
452
tlx .barrier_arrive (consumer_release_qk_view , 1 )
453
+ if PINGPONG :
454
+ tlx .named_barrier_arrive (10 , 128 )
431
455
432
456
# wait for o0, o1 per iteration
433
457
bufIdx = accum_cnt % NUM_BUFFERS_O
@@ -436,10 +460,12 @@ def gdpa_kernel_tma_ws_blackwell(
436
460
consumer_o0_view = tlx .local_view (producer_commit_o0 , bufIdx )
437
461
# tl.device_print("default producer_commit_o0", accum_cnt)
438
462
# tl.device_print("default producer_commit_o0_phase", phase)
439
- tlx .barrier_wait (consumer_o0_view , phase )
463
+ # there is no need to wait for o0 at each iteration
464
+ # tlx.barrier_wait(consumer_o0_view, phase)
440
465
accum_cnt += 1
441
466
442
467
# epilogue here, load from tmem
468
+ # FIXME: wait till o0 is done for the inner loop
443
469
bufIdx_o_outer , phase_o_outer = _get_bufidx_phase (
444
470
accum_cnt_outer , NUM_BUFFERS_O
445
471
)
@@ -472,9 +498,11 @@ def gdpa_kernel_tma_ws_blackwell(
472
498
accum_cnt_outer += 1
473
499
tile_idx += num_progs
474
500
475
- with tlx .async_task (num_warps = 4 , registers = 192 ):
501
+ with tlx .async_task (num_warps = 4 , registers = ACT_REGS ):
476
502
accum_cnt = 0
477
503
accum_cnt_outer = 0
504
+ if PINGPONG :
505
+ tlx .named_barrier_arrive (9 , 128 )
478
506
for _ in range (0 , tiles_per_sm ):
479
507
pid = tile_idx % n_tile_num
480
508
start_m = pid
@@ -505,32 +533,49 @@ def gdpa_kernel_tma_ws_blackwell(
505
533
# qk_view: BLOCK_M // 2, HEAD_DIM
506
534
qk_view_1st = tlx .subslice (qk_view , 0 , HEAD_DIM // 2 )
507
535
qk0 = tlx .local_load (qk_view_1st )
508
- p0 = fast_gelu (qk0 )
509
- p0 = p0 .to (dtype )
510
- p0_view = tlx .local_reinterpret (qk_view_1st , dtype )
511
- tlx .local_store (p0_view , p0 )
512
-
513
536
qk_view_2nd = tlx .subslice (
514
537
qk_view , HEAD_DIM // 2 , HEAD_DIM // 2
515
538
)
516
- qk0 = tlx .local_load (qk_view_2nd )
517
- p0 = fast_gelu (qk0 )
539
+ qk1 = tlx .local_load (qk_view_2nd )
540
+ c1 = 0.0356774081
541
+ c0 = 0.7978845608
542
+ square = _mul_f32x2 (qk0 , qk0 )
543
+ inner = _fma_f32x2 (c1 , square , c0 )
544
+ inner0 = _mul_f32x2 (inner , qk0 )
545
+ square = _mul_f32x2 (qk1 , qk1 )
546
+ inner = _fma_f32x2 (c1 , square , c0 )
547
+ inner1 = _mul_f32x2 (inner , qk1 )
548
+
549
+ if PINGPONG :
550
+ tlx .named_barrier_wait (10 , 128 )
551
+ # p0 = fast_gelu(qk0)
552
+ p0 = _fma_f32x2 (qk0 , tanh_approx_fp32 (inner0 ), qk0 )
518
553
p0 = p0 .to (dtype )
519
- p0_view = tlx .local_reinterpret (qk_view_2nd , dtype )
554
+ p0_view = tlx .local_reinterpret (qk_view_1st , dtype )
520
555
tlx .local_store (p0_view , p0 )
521
556
557
+ # p1 = fast_gelu(qk1)
558
+ p1 = _fma_f32x2 (qk1 , tanh_approx_fp32 (inner1 ), qk1 )
559
+ p1 = p1 .to (dtype )
560
+ p1_view = tlx .local_reinterpret (qk_view_2nd , dtype )
561
+ tlx .local_store (p1_view , p1 )
562
+
522
563
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
523
564
consumer_release_qk_view = tlx .local_view (producer_qk1 , bufIdx )
524
565
tlx .barrier_arrive (consumer_release_qk_view , 1 )
566
+ if PINGPONG :
567
+ tlx .named_barrier_arrive (9 , 128 )
525
568
526
569
# wait for o0, o1 per iteration
527
570
bufIdx = accum_cnt % NUM_BUFFERS_O
528
571
phase = (accum_cnt // NUM_BUFFERS_O ) & 1
529
572
# consumer wait of o1
530
573
consumer_o1_view = tlx .local_view (producer_commit_o1 , bufIdx )
531
- tlx .barrier_wait (consumer_o1_view , phase )
574
+ # there is no need to wait for o1 at each iteration
575
+ # tlx.barrier_wait(consumer_o1_view, phase)
532
576
accum_cnt += 1
533
577
# epilogue here, load from tmem
578
+ # FIXME: wait till o1 is done for the inner loop
534
579
bufIdx_o_outer , phase_o_outer = _get_bufidx_phase (
535
580
accum_cnt_outer , NUM_BUFFERS_O
536
581
)
@@ -1210,15 +1255,16 @@ def gdpa_forward_tlx(
1210
1255
1211
1256
stage = 1 # When supporting causal, change to 3
1212
1257
extra_kern_args = {}
1258
+ # extra_kern_args["maxnreg"] = 168
1213
1259
nheads = query .shape [1 ]
1214
1260
G = query .shape [1 ] // key .shape [1 ]
1215
1261
assert query .shape [1 ] % key .shape [1 ] == 0
1216
1262
batch_size = BATCH * nheads
1217
1263
NUM_SMS = (
1218
1264
get_num_sms () or 1000000
1219
- ) * 8 # if num sms is None, use a large number so that it is a no-op
1220
- print ("NUM_SMS" , NUM_SMS )
1221
- print (triton .cdiv (max_seq_len_q , 256 ) * BATCH * nheads )
1265
+ ) # * 8 # if num sms is None, use a large number so that it is a no-op
1266
+ # print("NUM_SMS", NUM_SMS)
1267
+ # print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads)
1222
1268
1223
1269
q = expect_contiguous (query )
1224
1270
k = expect_contiguous (key )
@@ -1268,7 +1314,7 @@ def grid_tma_persistent(META):
1268
1314
)
1269
1315
1270
1316
activation_enum_int = activation_string_to_int (activation )
1271
- print (q .shape , k .shape , v .shape )
1317
+ # print(q.shape, k.shape, v.shape)
1272
1318
# print("activation_enum_int", activation, activation_enum_int)
1273
1319
# print(query_offset)
1274
1320
# print(key_offset)
0 commit comments