 import triton
 import triton.language as tl
 import triton.language.extra.tlx as tlx
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 from .gdpa_utils import get_num_sms
 from .math import activation_string_to_int
 
 
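+# Autotune pre-hook: when Q/K/V/Out are passed as host-side TensorDescriptor
+# objects, resize their block shapes to match the config being tried; raw
+# pointers (on-device TMA) are skipped via the early return below.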
+def _host_descriptor_pre_hook(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    BLOCK_D = nargs["BLOCK_D"]
+    if not isinstance(nargs["Q"], TensorDescriptor):
+        # early return for on-device TMA
+        return
+    NUM_MMA_GROUPS = 2
+    BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
+    nargs["Q"].block_shape = [BLOCK_M_SPLIT, BLOCK_D]
+    nargs["V"].block_shape = [BLOCK_N, BLOCK_D]
+    nargs["K"].block_shape = [BLOCK_N, BLOCK_D]
+    nargs["Out"].block_shape = [BLOCK_M_SPLIT, BLOCK_D]
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config(
@@ -24,6 +40,7 @@ def get_cuda_autotune_config():
             },
             num_warps=4,
             num_stages=1,
+            pre_hook=_host_descriptor_pre_hook,
         )
         for BM in [256]  # 128 or 256
         for BN in [128]
@@ -198,6 +215,7 @@ def gdpa_kernel_tma_ws_blackwell(
     BROADCAST_Q: tl.constexpr,
     IS_DENSE_KV: tl.constexpr,
     activation_enum_int: tl.constexpr,
+    USE_ON_DEVICE_TMA: tl.constexpr,
     NUM_BUFFERS_Q: tl.constexpr,
     NUM_BUFFERS_KV: tl.constexpr,
     NUM_BUFFERS_QK: tl.constexpr,
@@ -214,21 +232,27 @@ def gdpa_kernel_tma_ws_blackwell(
         tiles_per_sm += 1
 
     tile_idx = prog_id
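+    # Host-side TMA: Q/K/V/Out already arrive as TensorDescriptor objects, so
+    # simply alias them; no in-kernel descriptor construction is needed.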
+    if not USE_ON_DEVICE_TMA:
+        q_desc = Q
+        k_desc = K
+        v_desc = V
+        o_desc = Out
 
     # start with on-device TMA where descriptors for k, v are set up outside of the persistent
     # loop and descriptor for q is set up inside the persistent loop.
-    k_desc = tl.make_tensor_descriptor(
-        K,
-        shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
-        strides=[HEAD_DIM * H // G, 1],
-        block_shape=[BLOCK_N, BLOCK_D],
-    )
-    v_desc = tl.make_tensor_descriptor(
-        V,
-        shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
-        strides=[HEAD_DIM * H // G, 1],
-        block_shape=[BLOCK_N, BLOCK_D],
-    )
+    if USE_ON_DEVICE_TMA:
+        k_desc = tl.make_tensor_descriptor(
+            K,
+            shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
+            strides=[HEAD_DIM * H // G, 1],
+            block_shape=[BLOCK_N, BLOCK_D],
+        )
+        v_desc = tl.make_tensor_descriptor(
+            V,
+            shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
+            strides=[HEAD_DIM * H // G, 1],
+            block_shape=[BLOCK_N, BLOCK_D],
+        )
 
     # allocate buffers for q0, q1
     q0_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), tl.float16, 1)
@@ -326,20 +350,12 @@ def gdpa_kernel_tma_ws_blackwell(
                     qk0 = tlx.local_load(qk_view)  # , tlx.storage_kind.tmem)
                     # ConsumerWait for qk, ProducerAcquire for p
                     # if activation_enum_int == 3:
-                    p0 = (
-                        qk0
-                        * 0.5
-                        * (
-                            1
-                            + tanh_approx_fp32(
-                                0.7978845608 * qk0 * (1.0 + 0.044715 * qk0 * qk0)
-                            )
-                        )
-                    )  # fast_gelu(qk0)
-                    # else:
-                    #     p0 = qk0
+                    p0 = fast_gelu(qk0)
                     p0 *= qk_scale
-                    p0 = p0.to(V.dtype.element_ty)  # v_dtype)
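+                    # Cast p0 to V's element type (on-device TMA) or to the dtype
+                    # recorded in the host-side descriptor.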
+                    if USE_ON_DEVICE_TMA:
+                        p0 = p0.to(V.dtype.element_ty)  # v_dtype)
+                    else:
+                        p0 = p0.to(tlx.dtype_of(v_desc))
                     qk_view = tlx.local_view(qk0_buf, bufIdx)
                     p0_view = tlx.local_reinterpret(qk_view, tl.float16)
                     tlx.local_store(p0_view, p0)  # , tlx.storage_kind.tmem)
@@ -371,18 +387,23 @@ def gdpa_kernel_tma_ws_blackwell(
                     )
                     # tl.device_print("default producer_o0", accum_cnt_outer)
                     tlx.barrier_arrive(consumer_release_o0_view, 1)
-                    o0_desc = tl.make_tensor_descriptor(
-                        Out,
-                        shape=[end_q.to(tl.int32), HEAD_DIM * H],
-                        strides=[HEAD_DIM * H, 1],
-                        block_shape=[BLOCK_M // 2, BLOCK_D],
-                    )
-                    o0_desc.store(
+                    if USE_ON_DEVICE_TMA:
+                        o_desc = tl.make_tensor_descriptor(
+                            Out,
+                            shape=[end_q.to(tl.int32), HEAD_DIM * H],
+                            strides=[HEAD_DIM * H, 1],
+                            block_shape=[BLOCK_M // 2, BLOCK_D],
+                        )
+                    if USE_ON_DEVICE_TMA:
+                        o0 = o0.to(Out.type.element_ty)
+                    else:
+                        o0 = o0.to(tlx.dtype_of(o_desc))
+                    o_desc.store(
                         [
                             (begin_q + start_m * BLOCK_M).to(tl.int32),
                             (out_offset).to(tl.int32),
                         ],
-                        o0.to(Out.type.element_ty),
+                        o0,
                     )
                     accum_cnt_outer += 1
                 tile_idx += num_progs
@@ -420,20 +441,12 @@ def gdpa_kernel_tma_ws_blackwell(
                     qk1 = tlx.local_load(qk_view)  # , tlx.storage_kind.tmem)
                     # ConsumerWait for qk, ProducerAcquire for p
                     # if activation_enum_int == 3:
-                    p1 = (
-                        qk1
-                        * 0.5
-                        * (
-                            1
-                            + tanh_approx_fp32(
-                                0.7978845608 * qk1 * (1.0 + 0.044715 * qk1 * qk1)
-                            )
-                        )
-                    )  # fast_gelu(qk1)
-                    # else:
-                    #     p1 = qk1
+                    p1 = fast_gelu(qk1)
                     p1 *= qk_scale
-                    p1 = p1.to(V.dtype.element_ty)  # v_dtype)
+                    if USE_ON_DEVICE_TMA:
+                        p1 = p1.to(V.dtype.element_ty)  # v_dtype)
+                    else:
+                        p1 = p1.to(tlx.dtype_of(v_desc))
                     qk_view = tlx.local_view(qk1_buf, bufIdx)
                     p1_view = tlx.local_reinterpret(qk_view, tl.float16)
                     tlx.local_store(p1_view, p1)  # , tlx.storage_kind.tmem)
@@ -452,12 +465,13 @@ def gdpa_kernel_tma_ws_blackwell(
                     bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
                         accum_cnt_outer, NUM_BUFFERS_O
                     )
-                    o0_desc = tl.make_tensor_descriptor(
-                        Out,
-                        shape=[end_q.to(tl.int32), HEAD_DIM * H],
-                        strides=[HEAD_DIM * H, 1],
-                        block_shape=[BLOCK_M // 2, BLOCK_D],
-                    )
+                    if USE_ON_DEVICE_TMA:
+                        o_desc = tl.make_tensor_descriptor(
+                            Out,
+                            shape=[end_q.to(tl.int32), HEAD_DIM * H],
+                            strides=[HEAD_DIM * H, 1],
+                            block_shape=[BLOCK_M // 2, BLOCK_D],
+                        )
                     o1_view = tlx.local_view(
                         o1_buf, bufIdx_o_outer
                     )  # FIXME: should be 0
@@ -467,12 +481,16 @@ def gdpa_kernel_tma_ws_blackwell(
                         producer_o1, bufIdx_o_outer
                     )
                     tlx.barrier_arrive(consumer_release_o1_view, 1)
-                    o0_desc.store(
+                    if USE_ON_DEVICE_TMA:
+                        o1 = o1.to(Out.type.element_ty)
+                    else:
+                        o1 = o1.to(tlx.dtype_of(o_desc))
+                    o_desc.store(
                         [
                             (begin_q + start_m * BLOCK_M + BLOCK_M // 2).to(tl.int32),
                             (out_offset).to(tl.int32),
                         ],
-                        o1.to(Out.type.element_ty),
+                        o1,
                     )
                     accum_cnt_outer += 1
                 tile_idx += num_progs
@@ -581,6 +599,7 @@ def gdpa_kernel_tma_ws_blackwell(
                     producer_o1_view = tlx.local_view(producer_o1, bufIdx_o_outer)
                     # tl.device_print("gemm producer_o0", accum_cnt_outer)
                     # tl.device_print("gemm producer_o0_phase", phase_o_outer)
+                    # DEBUG_PERF
                     tlx.barrier_wait(
                         producer_o0_view, phase_o_outer ^ 1
                     )  # producer acquire for o0
@@ -591,6 +610,7 @@ def gdpa_kernel_tma_ws_blackwell(
                     consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_p)
                     # tl.device_print("gemm producer_qk0", accum_cnt_qk)
                     # tl.device_print("gemm producer_qk0_phase", phase_p)
+                    # DEBUG_PERF_P
                     tlx.barrier_wait(
                         consumer_p0_view, phase_p
                     )  # consumer wait for p0 due to reuse of p0 and qk0
@@ -660,11 +680,13 @@ def gdpa_kernel_tma_ws_blackwell(
                     consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
                     # tl.device_print("gemm producer_o1", accum_cnt_outer)
                     # tl.device_print("gemm producer_o1_phase", phase_o_outer)
+                    # DEBUG_PERF
                     tlx.barrier_wait(
                         producer_o1_view, phase_o_outer ^ 1, first
                     )  # producer acquire for o1, only needed for first iteration
                     # tl.device_print("gemm producer_qk1", accum_cnt_qk1)
                     # tl.device_print("gemm producer_qk1_phase", phase_qk1)
+                    # DEBUG_PERF_P
                     tlx.barrier_wait(
                         consumer_p1_view, phase_qk1
                     )  # consumer wait for p1 use producer_qk1 due to reuse
@@ -741,6 +763,7 @@ def gdpa_kernel_tma_ws_blackwell(
                     consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_qk)
                     # tl.device_print("gemm producer_qk0", accum_cnt_qk)
                     # tl.device_print("gemm producer_qk0_phase", phase_qk)
+                    # DEBUG_PERF_P
                     tlx.barrier_wait(
                         consumer_p0_view, phase_qk
                     )  # consumer wait for p0 use producer_qk0 due to reuse
@@ -780,6 +803,7 @@ def gdpa_kernel_tma_ws_blackwell(
                     tlx.tcgen05_commit(release_q1_view)
                     # tl.device_print("gemm producer_o1_epilogue", accum_cnt_outer)
                     # tl.device_print("gemm producer_o1_phase", phase_o_outer)
+                    # DEBUG_PERF
                     tlx.barrier_wait(
                         producer_o1_view, phase_o_outer ^ 1, first
                     )  # producer acquire for o1 at the first iteration
@@ -789,6 +813,7 @@ def gdpa_kernel_tma_ws_blackwell(
                     consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
                     # tl.device_print("gemm producer_qk1_epilogue", accum_cnt_qk1)
                     # tl.device_print("gemm producer_qk1_phase", phase_qk1)
+                    # DEBUG_PERF_P
                     tlx.barrier_wait(
                         consumer_p1_view, phase_qk1
                     )  # consumer wait for p1 due to reuse of p1 and qk1
@@ -862,12 +887,13 @@ def gdpa_kernel_tma_ws_blackwell(
                 if start_m * BLOCK_M < qlen:
                     # begin_o = tl.load(Out_offsets + off_z)  # confirm if tma store should use begin_q
 
-                    q_desc = tl.make_tensor_descriptor(
-                        Q,
-                        shape=[end_q.to(tl.int32), HEAD_DIM * H],
-                        strides=[HEAD_DIM * H, 1],
-                        block_shape=[BLOCK_M // 2, BLOCK_D],
-                    )
+                    if USE_ON_DEVICE_TMA:
+                        q_desc = tl.make_tensor_descriptor(
+                            Q,
+                            shape=[end_q.to(tl.int32), HEAD_DIM * H],
+                            strides=[HEAD_DIM * H, 1],
+                            block_shape=[BLOCK_M // 2, BLOCK_D],
+                        )
 
                     # calculate bufIdx and phase from accum_count_q
                     q_bufIdx = accum_count_q % NUM_BUFFERS_Q
@@ -1131,6 +1157,40 @@ def gdpa_forward_tlx(
         print("NUM_SMS", NUM_SMS)
         print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads)
 
+    q = expect_contiguous(query)
+    k = expect_contiguous(key)
+    v = expect_contiguous(value)
+    kstrides = k.stride()
+    vstrides = v.stride()
+
+    dummy_block = [1, 1]
+    N_CTX_KV = max_seq_len_kv
+    HEAD_DIM = HEAD_DIM_K
+    Z = BATCH
+    H = nheads
+    y_dim = N_CTX_KV * Z
+    x_dim = HEAD_DIM * H // G
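+    # Select on-device TMA (raw tensors; descriptors built inside the kernel) or
+    # host-side TMA (TensorDescriptor objects; block shapes are filled in by the
+    # autotune pre-hook, so dummy_block is only a placeholder).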
+    USE_ON_DEVICE_TMA = True
+    if not USE_ON_DEVICE_TMA:
+        desc_q = TensorDescriptor(
+            q,
+            shape=[y_dim, HEAD_DIM * H],
+            strides=[HEAD_DIM * H, 1],
+            block_shape=dummy_block,
+        )
+        desc_v = TensorDescriptor(
+            v, shape=[y_dim, x_dim], strides=[x_dim, 1], block_shape=dummy_block
+        )
+        desc_k = TensorDescriptor(
+            k, shape=[y_dim, x_dim], strides=[x_dim, 1], block_shape=dummy_block
+        )
+        desc_o = TensorDescriptor(
+            o,
+            shape=[y_dim, HEAD_DIM * H],
+            strides=[HEAD_DIM * H, 1],
+            block_shape=dummy_block,
+        )
+
     # TMA descriptors require a global memory allocation
     def alloc_fn(size: int, alignment: int, _):
         return torch.empty(size, device="cuda", dtype=torch.int8)
@@ -1144,22 +1204,19 @@ def grid_tma_persistent(META):
             1,
         )
 
-    q = expect_contiguous(query)
-    k = expect_contiguous(key)
-    v = expect_contiguous(value)
-    kstrides = k.stride()
-    vstrides = v.stride()
-
     activation_enum_int = activation_string_to_int(activation)
+    print(q.shape, k.shape, v.shape)
     # print("activation_enum_int", activation, activation_enum_int)
+    # print(query_offset)
+    # print(key_offset)
 
     gdpa_kernel_tma_ws_blackwell[grid_tma_persistent](
-        q,
+        q if USE_ON_DEVICE_TMA else desc_q,
         query_offset,
-        k,
+        k if USE_ON_DEVICE_TMA else desc_k,
         key_offset,
-        v,
-        o,  #
+        v if USE_ON_DEVICE_TMA else desc_v,
+        o if USE_ON_DEVICE_TMA else desc_o,
         output_offset,
         ad_to_request_offset,
         seq_index,
@@ -1194,6 +1251,7 @@ def grid_tma_persistent(META):
         BROADCAST_Q=broadcast_q,
         IS_DENSE_KV=is_dense_kv,
         activation_enum_int=activation_enum_int,
+        USE_ON_DEVICE_TMA=USE_ON_DEVICE_TMA,
         **extra_kern_args,
     )
     return o