@@ -584,8 +584,10 @@ def gdpa_kernel_tma_ws_blackwell(
584
584
out_offset = off_h .to (tl .int64 ) * stride_oh
585
585
if start_m * BLOCK_M < qlen :
586
586
lo , hi = 0 , klen
587
+ tl .device_print ("default" , hi )
587
588
for start_n in range (lo , hi , BLOCK_N ):
588
589
start_n = tl .multiple_of (start_n , BLOCK_N )
590
+ tl .device_print ("default start_n" , start_n )
589
591
## communication channel for qk0, p0
590
592
# _do_activation(
591
593
# qk0_buf,
@@ -602,6 +604,8 @@ def gdpa_kernel_tma_ws_blackwell(
602
604
phase = (accum_cnt // NUM_BUFFERS_QK ) & 1
603
605
qk_view = tlx .local_view (qk0_buf , bufIdx )
604
606
consumer_qk_view = tlx .local_view (producer_commit_qk0 , bufIdx )
607
+ tl .device_print ("producer_commit_qk0" , accum_cnt )
608
+ tl .device_print ("producer_commit_qk0_phase" , phase )
605
609
tlx .barrier_wait (consumer_qk_view , phase )
606
610
qk0 = tlx .local_load (qk_view ) # , tlx.storage_kind.tmem)
607
611
# ConsumerWait for qk, ProducerAcquire for p
@@ -632,6 +636,8 @@ def gdpa_kernel_tma_ws_blackwell(
632
636
phase = (accum_cnt // NUM_BUFFERS_O ) & 1
633
637
# consumer wait of o0: producer_commit
634
638
consumer_o0_view = tlx .local_view (producer_commit_o0 , bufIdx )
639
+ tl .device_print ("producer_commit_o0" , accum_cnt )
640
+ tl .device_print ("producer_commit_o0_phase" , phase )
635
641
tlx .barrier_wait (consumer_o0_view , phase )
636
642
accum_cnt += 1
637
643
@@ -647,6 +653,7 @@ def gdpa_kernel_tma_ws_blackwell(
647
653
consumer_release_o0_view = tlx .local_view (
648
654
producer_o0 , bufIdx_o_outer
649
655
)
656
+ tl .device_print ("arrive producer_o0" , accum_cnt_outer )
650
657
tlx .barrier_arrive (consumer_release_o0_view , 1 )
651
658
o0_desc = tl .make_tensor_descriptor (
652
659
Out ,
@@ -662,6 +669,7 @@ def gdpa_kernel_tma_ws_blackwell(
662
669
o0 .to (Out .type .element_ty ),
663
670
)
664
671
accum_cnt_outer += 1
672
+ tile_idx += num_progs
665
673
666
674
with tlx .async_task (num_warps = 4 ):
667
675
accum_cnt = 0
@@ -761,6 +769,7 @@ def gdpa_kernel_tma_ws_blackwell(
761
769
o1 .to (Out .type .element_ty ),
762
770
)
763
771
accum_cnt_outer += 1
772
+ tile_idx += num_progs
764
773
765
774
with tlx .async_task (num_warps = 1 ): # gemm
766
775
accum_cnt_q = 0
@@ -829,7 +838,11 @@ def gdpa_kernel_tma_ws_blackwell(
829
838
consumer_q0_view = tlx .local_view (consumer_q0 , bufIdx_q )
830
839
consumer_k_view = tlx .local_view (consumer_k , bufIdx_k )
831
840
# producer_qk0_view = tlx.local_view(producer_qk0, bufIdx_qk)
841
+ tl .device_print ("consumer_q0_prologue" , accum_cnt_q )
842
+ tl .device_print ("consumer_q0_phase" , phase_q )
832
843
tlx .barrier_wait (consumer_q0_view , phase_q ) # consumer wait for q0
844
+ tl .device_print ("consumer_k" , accum_cnt_k )
845
+ tl .device_print ("consumer_k_phase" , phase_k )
833
846
tlx .barrier_wait (consumer_k_view , phase_k ) # consumer wait for k
834
847
# Do we need the initial acquire here?
835
848
# dot partition has producer commit for qk0, activation partition consumer wait for qk0
@@ -853,6 +866,8 @@ def gdpa_kernel_tma_ws_blackwell(
853
866
854
867
consumer_q1_view = tlx .local_view (consumer_q1 , bufIdx_q )
855
868
# producer_qk1_view = tlx.local_view(producer_qk1, bufIdx_qk)
869
+ tl .device_print ("consumer_q1" , accum_cnt_q )
870
+ tl .device_print ("consumer_q1_phase" , phase_q )
856
871
tlx .barrier_wait (consumer_q1_view , phase_q ) # consumer wait for q1
857
872
# tlx.barrier_wait(producer_qk1_view, phase_qk) # producer acquire for qk1
858
873
# consumer release for k, producer commit for qk1
@@ -874,13 +889,17 @@ def gdpa_kernel_tma_ws_blackwell(
874
889
# accum_cnt_qk1 += 1
875
890
876
891
consumer_v_view = tlx .local_view (consumer_v , bufIdx_k )
892
+ tl .device_print ("consumer_v" , accum_cnt_k )
893
+ tl .device_print ("consumer_v_phase" , phase_k )
877
894
tlx .barrier_wait (consumer_v_view , phase_k ) # consumer wait for v
878
895
# need to acquire o0 to make sure epilogue is done, this is needed for each outer loop
879
896
bufIdx_o_outer , phase_o_outer = _get_bufidx_phase (
880
897
accum_cnt_outer , NUM_BUFFERS_O
881
898
)
882
899
producer_o0_view = tlx .local_view (producer_o0 , bufIdx_o_outer )
883
900
producer_o1_view = tlx .local_view (producer_o1 , bufIdx_o_outer )
901
+ tl .device_print ("producer_o0" , accum_cnt_outer )
902
+ tl .device_print ("producer_o0_phase" , phase_o_outer )
884
903
tlx .barrier_wait (
885
904
producer_o0_view , phase_o_outer ^ 1
886
905
) # producer acquire for o0
@@ -889,6 +908,8 @@ def gdpa_kernel_tma_ws_blackwell(
889
908
# dot partition: producer commit of qk0, ..., consumer wait for p0 (use the same barrier as producer_qk0)
890
909
bufIdx_p , phase_p = _get_bufidx_phase (accum_cnt_qk , NUM_BUFFERS_QK )
891
910
consumer_p0_view = tlx .local_view (producer_qk0 , bufIdx_p )
911
+ tl .device_print ("producer_qk0" , accum_cnt_qk )
912
+ tl .device_print ("producer_qk0_phase" , phase_p )
892
913
tlx .barrier_wait (
893
914
consumer_p0_view , phase_p
894
915
) # consumer wait for p0 due to reuse of p0 and qk0
@@ -917,7 +938,11 @@ def gdpa_kernel_tma_ws_blackwell(
917
938
mma_iters = (hi - lo ) // BLOCK_N
918
939
accum_cnt_k += 1
919
940
accum_cnt_qk += 1
920
- for _ in range (mma_iters - 1 ):
941
+ tl .device_print ("gemm for " , hi )
942
+ tl .device_print ("gemm mma_iters " , mma_iters )
943
+ for it in range (BLOCK_N , hi , BLOCK_N ):
944
+ # for it in range(mma_iters - 1):
945
+ tl .device_print ("gemm iter " , it )
921
946
bufIdx_k , phase_k = _get_bufidx_phase (
922
947
accum_cnt_k , NUM_BUFFERS_K
923
948
)
@@ -927,6 +952,8 @@ def gdpa_kernel_tma_ws_blackwell(
927
952
928
953
# q0 dot k
929
954
consumer_k_view = tlx .local_view (consumer_k , bufIdx_k )
955
+ tl .device_print ("consumer_k" , accum_cnt_k )
956
+ tl .device_print ("consumer_k_phase" , phase_k )
930
957
tlx .barrier_wait (
931
958
consumer_k_view , phase_k
932
959
) # consumer wait for k
@@ -948,9 +975,13 @@ def gdpa_kernel_tma_ws_blackwell(
948
975
accum_cnt_qk1 , NUM_BUFFERS_QK
949
976
)
950
977
consumer_p1_view = tlx .local_view (producer_qk1 , bufIdx_qk1 )
978
+ tl .device_print ("producer_o1" , accum_cnt_outer )
979
+ tl .device_print ("producer_o1_phase" , phase_o_outer )
951
980
tlx .barrier_wait (
952
981
producer_o1_view , phase_o_outer ^ 1 , first
953
982
) # producer acquire for o1, only needed for first iteration
983
+ tl .device_print ("producer_qk1" , accum_cnt_qk1 )
984
+ tl .device_print ("producer_qk1_phase" , phase_qk1 )
954
985
tlx .barrier_wait (
955
986
consumer_p1_view , phase_qk1
956
987
) # consumer wait for p1 use producer_qk1 due to reuse
@@ -1007,12 +1038,16 @@ def gdpa_kernel_tma_ws_blackwell(
1007
1038
1008
1039
# p0 dot v
1009
1040
consumer_v_view = tlx .local_view (consumer_v , bufIdx_k )
1041
+ tl .device_print ("consumer_v" , accum_cnt_k )
1042
+ tl .device_print ("consumer_v_phase" , phase_k )
1010
1043
tlx .barrier_wait (
1011
1044
consumer_v_view , phase_k
1012
1045
) # consumer wait for v
1013
1046
# no need to acquire o0 as this is the only partition updating it
1014
1047
# tlx.barrier_wait(producer_o0) # producer acquire for o0
1015
1048
consumer_p0_view = tlx .local_view (producer_qk0 , bufIdx_qk )
1049
+ tl .device_print ("producer_qk0" , accum_cnt_qk )
1050
+ tl .device_print ("producer_qk0_phase" , phase_qk )
1016
1051
tlx .barrier_wait (
1017
1052
consumer_p0_view , phase_qk
1018
1053
) # consumer wait for p0 use producer_qk0 due to reuse
@@ -1049,13 +1084,17 @@ def gdpa_kernel_tma_ws_blackwell(
1049
1084
tlx .tcgen05_commit (release_q0_view )
1050
1085
release_q1_view = tlx .local_view (consumer_release_q1 , bufIdx_q )
1051
1086
tlx .tcgen05_commit (release_q1_view )
1087
+ tl .device_print ("producer_o1_epilogue" , accum_cnt_outer )
1088
+ tl .device_print ("producer_o1_phase" , phase_o_outer )
1052
1089
tlx .barrier_wait (
1053
1090
producer_o1_view , phase_o_outer ^ 1 , first
1054
1091
) # producer acquire for o1 at the first iteration
1055
1092
bufIdx_qk1 , phase_qk1 = _get_bufidx_phase (
1056
1093
accum_cnt_qk1 , NUM_BUFFERS_QK
1057
1094
)
1058
1095
consumer_p1_view = tlx .local_view (producer_qk1 , bufIdx_qk1 )
1096
+ tl .device_print ("producer_qk1_epilogue" , accum_cnt_qk1 )
1097
+ tl .device_print ("producer_qk1_phase" , phase_qk1 )
1059
1098
tlx .barrier_wait (
1060
1099
consumer_p1_view , phase_qk1
1061
1100
) # consumer wait for p1 due to reuse of p1 and qk1
@@ -1092,6 +1131,7 @@ def gdpa_kernel_tma_ws_blackwell(
1092
1131
accum_cnt_outer += 1
1093
1132
# signal producer commit of epi0 and epi1, we don't want to block the gemm partition
1094
1133
# to wait for the completion
1134
+ tile_idx += num_progs
1095
1135
1096
1136
with tlx .async_task (num_warps = 1 ): # load
1097
1137
accum_count_q = 0
@@ -1153,7 +1193,7 @@ def gdpa_kernel_tma_ws_blackwell(
1153
1193
consumer_q0 , q_bufIdx
1154
1194
) # full_bars, bufIdx)
1155
1195
tlx .barrier_expect_bytes (
1156
- q0_full_view , BLOCK_M * BLOCK_D * 2
1196
+ q0_full_view , BLOCK_M // 2 * BLOCK_D * 2
1157
1197
) # num_bytes)
1158
1198
q0_smem_view = tlx .local_view (q0_buf , q_bufIdx )
1159
1199
tlx .async_descriptor_load (
@@ -1183,7 +1223,7 @@ def gdpa_kernel_tma_ws_blackwell(
1183
1223
# barrier for producer commit
1184
1224
q1_full_view = tlx .local_view (consumer_q1 , q_bufIdx )
1185
1225
tlx .barrier_expect_bytes (
1186
- q1_full_view , BLOCK_M * BLOCK_D * 2
1226
+ q1_full_view , BLOCK_M // 2 * BLOCK_D * 2
1187
1227
) # num_bytes)
1188
1228
q1_smem_view = tlx .local_view (q1_buf , q_bufIdx )
1189
1229
tlx .async_descriptor_load (
@@ -1259,8 +1299,9 @@ def gdpa_kernel_tma_ws_blackwell(
1259
1299
v_full_view ,
1260
1300
)
1261
1301
accum_count_k += 1
1262
-
1302
+ # outside of inner for
1263
1303
accum_count_q += 1
1304
+ tile_idx += num_progs
1264
1305
1265
1306
1266
1307
def next_power_of_2 (x ):
@@ -1375,7 +1416,7 @@ def alloc_fn(size: int, alignment: int, _):
1375
1416
1376
1417
def grid_tma_persistent (META ):
1377
1418
return (
1378
- min (NUM_SMS , triton .cdiv (max_seq_len_q , META ["BLOCK_M" ]) * BATCH * nheads ),
1419
+ 1 , # min(NUM_SMS, triton.cdiv(max_seq_len_q, META["BLOCK_M"]) * BATCH * nheads),
1379
1420
1 ,
1380
1421
1 ,
1381
1422
)
0 commit comments