@@ -584,10 +584,10 @@ def gdpa_kernel_tma_ws_blackwell(
584
584
out_offset = off_h .to (tl .int64 ) * stride_oh
585
585
if start_m * BLOCK_M < qlen :
586
586
lo , hi = 0 , klen
587
- tl .device_print ("default" , hi )
587
+ # tl.device_print("default", hi)
588
588
for start_n in range (lo , hi , BLOCK_N ):
589
589
start_n = tl .multiple_of (start_n , BLOCK_N )
590
- tl .device_print ("default start_n" , start_n )
590
+ # tl.device_print("default start_n", start_n)
591
591
## communication channel for qk0, p0
592
592
# _do_activation(
593
593
# qk0_buf,
@@ -604,8 +604,8 @@ def gdpa_kernel_tma_ws_blackwell(
604
604
phase = (accum_cnt // NUM_BUFFERS_QK ) & 1
605
605
qk_view = tlx .local_view (qk0_buf , bufIdx )
606
606
consumer_qk_view = tlx .local_view (producer_commit_qk0 , bufIdx )
607
- tl .device_print ("producer_commit_qk0" , accum_cnt )
608
- tl .device_print ("producer_commit_qk0_phase" , phase )
607
+ # tl.device_print("producer_commit_qk0", accum_cnt)
608
+ # tl.device_print("producer_commit_qk0_phase", phase)
609
609
tlx .barrier_wait (consumer_qk_view , phase )
610
610
qk0 = tlx .local_load (qk_view ) # , tlx.storage_kind.tmem)
611
611
# ConsumerWait for qk, ProducerAcquire for p
@@ -636,8 +636,8 @@ def gdpa_kernel_tma_ws_blackwell(
636
636
phase = (accum_cnt // NUM_BUFFERS_O ) & 1
637
637
# consumer wait of o0: producer_commit
638
638
consumer_o0_view = tlx .local_view (producer_commit_o0 , bufIdx )
639
- tl .device_print ("producer_commit_o0" , accum_cnt )
640
- tl .device_print ("producer_commit_o0_phase" , phase )
639
+ # tl.device_print("producer_commit_o0", accum_cnt)
640
+ # tl.device_print("producer_commit_o0_phase", phase)
641
641
tlx .barrier_wait (consumer_o0_view , phase )
642
642
accum_cnt += 1
643
643
@@ -653,7 +653,7 @@ def gdpa_kernel_tma_ws_blackwell(
653
653
consumer_release_o0_view = tlx .local_view (
654
654
producer_o0 , bufIdx_o_outer
655
655
)
656
- tl .device_print ("arrive producer_o0" , accum_cnt_outer )
656
+ # tl.device_print("arrive producer_o0", accum_cnt_outer)
657
657
tlx .barrier_arrive (consumer_release_o0_view , 1 )
658
658
o0_desc = tl .make_tensor_descriptor (
659
659
Out ,
@@ -838,11 +838,11 @@ def gdpa_kernel_tma_ws_blackwell(
838
838
consumer_q0_view = tlx .local_view (consumer_q0 , bufIdx_q )
839
839
consumer_k_view = tlx .local_view (consumer_k , bufIdx_k )
840
840
# producer_qk0_view = tlx.local_view(producer_qk0, bufIdx_qk)
841
- tl .device_print ("consumer_q0_prologue" , accum_cnt_q )
842
- tl .device_print ("consumer_q0_phase" , phase_q )
841
+ # tl.device_print("consumer_q0_prologue", accum_cnt_q)
842
+ # tl.device_print("consumer_q0_phase", phase_q)
843
843
tlx .barrier_wait (consumer_q0_view , phase_q ) # consumer wait for q0
844
- tl .device_print ("consumer_k" , accum_cnt_k )
845
- tl .device_print ("consumer_k_phase" , phase_k )
844
+ # tl.device_print("consumer_k", accum_cnt_k)
845
+ # tl.device_print("consumer_k_phase", phase_k)
846
846
tlx .barrier_wait (consumer_k_view , phase_k ) # consumer wait for k
847
847
# Do we need the initial acquire here?
848
848
# dot partition has producer commit for qk0, activation partition consumer wait for qk0
@@ -866,8 +866,8 @@ def gdpa_kernel_tma_ws_blackwell(
866
866
867
867
consumer_q1_view = tlx .local_view (consumer_q1 , bufIdx_q )
868
868
# producer_qk1_view = tlx.local_view(producer_qk1, bufIdx_qk)
869
- tl .device_print ("consumer_q1" , accum_cnt_q )
870
- tl .device_print ("consumer_q1_phase" , phase_q )
869
+ # tl.device_print("consumer_q1", accum_cnt_q)
870
+ # tl.device_print("consumer_q1_phase", phase_q)
871
871
tlx .barrier_wait (consumer_q1_view , phase_q ) # consumer wait for q1
872
872
# tlx.barrier_wait(producer_qk1_view, phase_qk) # producer acquire for qk1
873
873
# consumer release for k, producer commit for qk1
@@ -889,17 +889,17 @@ def gdpa_kernel_tma_ws_blackwell(
889
889
# accum_cnt_qk1 += 1
890
890
891
891
consumer_v_view = tlx .local_view (consumer_v , bufIdx_k )
892
- tl .device_print ("consumer_v" , accum_cnt_k )
893
- tl .device_print ("consumer_v_phase" , phase_k )
892
+ # tl.device_print("consumer_v", accum_cnt_k)
893
+ # tl.device_print("consumer_v_phase", phase_k)
894
894
tlx .barrier_wait (consumer_v_view , phase_k ) # consumer wait for v
895
895
# need to acquire o0 to make sure epilogue is done, this is needed for each outer loop
896
896
bufIdx_o_outer , phase_o_outer = _get_bufidx_phase (
897
897
accum_cnt_outer , NUM_BUFFERS_O
898
898
)
899
899
producer_o0_view = tlx .local_view (producer_o0 , bufIdx_o_outer )
900
900
producer_o1_view = tlx .local_view (producer_o1 , bufIdx_o_outer )
901
- tl .device_print ("producer_o0" , accum_cnt_outer )
902
- tl .device_print ("producer_o0_phase" , phase_o_outer )
901
+ # tl.device_print("producer_o0", accum_cnt_outer)
902
+ # tl.device_print("producer_o0_phase", phase_o_outer)
903
903
tlx .barrier_wait (
904
904
producer_o0_view , phase_o_outer ^ 1
905
905
) # producer acquire for o0
@@ -908,8 +908,8 @@ def gdpa_kernel_tma_ws_blackwell(
908
908
# dot partition: producer commit of qk0, ..., consumer wait for p0 (use the same barrier as producer_qk0)
909
909
bufIdx_p , phase_p = _get_bufidx_phase (accum_cnt_qk , NUM_BUFFERS_QK )
910
910
consumer_p0_view = tlx .local_view (producer_qk0 , bufIdx_p )
911
- tl .device_print ("producer_qk0" , accum_cnt_qk )
912
- tl .device_print ("producer_qk0_phase" , phase_p )
911
+ # tl.device_print("producer_qk0", accum_cnt_qk)
912
+ # tl.device_print("producer_qk0_phase", phase_p)
913
913
tlx .barrier_wait (
914
914
consumer_p0_view , phase_p
915
915
) # consumer wait for p0 due to reuse of p0 and qk0
@@ -938,11 +938,11 @@ def gdpa_kernel_tma_ws_blackwell(
938
938
mma_iters = (hi - lo ) // BLOCK_N
939
939
accum_cnt_k += 1
940
940
accum_cnt_qk += 1
941
- tl .device_print ("gemm for " , hi )
942
- tl .device_print ("gemm mma_iters " , mma_iters )
941
+ # tl.device_print("gemm for ", hi)
942
+ # tl.device_print("gemm mma_iters ", mma_iters)
943
943
for it in range (BLOCK_N , hi , BLOCK_N ):
944
944
# for it in range(mma_iters - 1):
945
- tl .device_print ("gemm iter " , it )
945
+ # tl.device_print("gemm iter ", it)
946
946
bufIdx_k , phase_k = _get_bufidx_phase (
947
947
accum_cnt_k , NUM_BUFFERS_K
948
948
)
@@ -952,8 +952,8 @@ def gdpa_kernel_tma_ws_blackwell(
952
952
953
953
# q0 dot k
954
954
consumer_k_view = tlx .local_view (consumer_k , bufIdx_k )
955
- tl .device_print ("consumer_k" , accum_cnt_k )
956
- tl .device_print ("consumer_k_phase" , phase_k )
955
+ # tl.device_print("consumer_k", accum_cnt_k)
956
+ # tl.device_print("consumer_k_phase", phase_k)
957
957
tlx .barrier_wait (
958
958
consumer_k_view , phase_k
959
959
) # consumer wait for k
@@ -975,13 +975,13 @@ def gdpa_kernel_tma_ws_blackwell(
975
975
accum_cnt_qk1 , NUM_BUFFERS_QK
976
976
)
977
977
consumer_p1_view = tlx .local_view (producer_qk1 , bufIdx_qk1 )
978
- tl .device_print ("producer_o1" , accum_cnt_outer )
979
- tl .device_print ("producer_o1_phase" , phase_o_outer )
978
+ # tl.device_print("producer_o1", accum_cnt_outer)
979
+ # tl.device_print("producer_o1_phase", phase_o_outer)
980
980
tlx .barrier_wait (
981
981
producer_o1_view , phase_o_outer ^ 1 , first
982
982
) # producer acquire for o1, only needed for first iteration
983
- tl .device_print ("producer_qk1" , accum_cnt_qk1 )
984
- tl .device_print ("producer_qk1_phase" , phase_qk1 )
983
+ # tl.device_print("producer_qk1", accum_cnt_qk1)
984
+ # tl.device_print("producer_qk1_phase", phase_qk1)
985
985
tlx .barrier_wait (
986
986
consumer_p1_view , phase_qk1
987
987
) # consumer wait for p1 use producer_qk1 due to reuse
@@ -1038,16 +1038,16 @@ def gdpa_kernel_tma_ws_blackwell(
1038
1038
1039
1039
# p0 dot v
1040
1040
consumer_v_view = tlx .local_view (consumer_v , bufIdx_k )
1041
- tl .device_print ("consumer_v" , accum_cnt_k )
1042
- tl .device_print ("consumer_v_phase" , phase_k )
1041
+ # tl.device_print("consumer_v", accum_cnt_k)
1042
+ # tl.device_print("consumer_v_phase", phase_k)
1043
1043
tlx .barrier_wait (
1044
1044
consumer_v_view , phase_k
1045
1045
) # consumer wait for v
1046
1046
# no need to acquire o0 as this is the only partition updating it
1047
1047
# tlx.barrier_wait(producer_o0) # producer acquire for o0
1048
1048
consumer_p0_view = tlx .local_view (producer_qk0 , bufIdx_qk )
1049
- tl .device_print ("producer_qk0" , accum_cnt_qk )
1050
- tl .device_print ("producer_qk0_phase" , phase_qk )
1049
+ # tl.device_print("producer_qk0", accum_cnt_qk)
1050
+ # tl.device_print("producer_qk0_phase", phase_qk)
1051
1051
tlx .barrier_wait (
1052
1052
consumer_p0_view , phase_qk
1053
1053
) # consumer wait for p0 use producer_qk0 due to reuse
@@ -1084,17 +1084,17 @@ def gdpa_kernel_tma_ws_blackwell(
1084
1084
tlx .tcgen05_commit (release_q0_view )
1085
1085
release_q1_view = tlx .local_view (consumer_release_q1 , bufIdx_q )
1086
1086
tlx .tcgen05_commit (release_q1_view )
1087
- tl .device_print ("producer_o1_epilogue" , accum_cnt_outer )
1088
- tl .device_print ("producer_o1_phase" , phase_o_outer )
1087
+ # tl.device_print("producer_o1_epilogue", accum_cnt_outer)
1088
+ # tl.device_print("producer_o1_phase", phase_o_outer)
1089
1089
tlx .barrier_wait (
1090
1090
producer_o1_view , phase_o_outer ^ 1 , first
1091
1091
) # producer acquire for o1 at the first iteration
1092
1092
bufIdx_qk1 , phase_qk1 = _get_bufidx_phase (
1093
1093
accum_cnt_qk1 , NUM_BUFFERS_QK
1094
1094
)
1095
1095
consumer_p1_view = tlx .local_view (producer_qk1 , bufIdx_qk1 )
1096
- tl .device_print ("producer_qk1_epilogue" , accum_cnt_qk1 )
1097
- tl .device_print ("producer_qk1_phase" , phase_qk1 )
1096
+ # tl.device_print("producer_qk1_epilogue", accum_cnt_qk1)
1097
+ # tl.device_print("producer_qk1_phase", phase_qk1)
1098
1098
tlx .barrier_wait (
1099
1099
consumer_p1_view , phase_qk1
1100
1100
) # consumer wait for p1 due to reuse of p1 and qk1
@@ -1416,7 +1416,7 @@ def alloc_fn(size: int, alignment: int, _):
1416
1416
1417
1417
def grid_tma_persistent (META ):
1418
1418
return (
1419
- 1 , # min(NUM_SMS, triton.cdiv(max_seq_len_q, META["BLOCK_M"]) * BATCH * nheads),
1419
+ min (NUM_SMS , triton .cdiv (max_seq_len_q , META ["BLOCK_M" ]) * BATCH * nheads ),
1420
1420
1 ,
1421
1421
1 ,
1422
1422
)
@@ -1428,7 +1428,7 @@ def grid_tma_persistent(META):
1428
1428
vstrides = v .stride ()
1429
1429
1430
1430
activation_enum_int = activation_string_to_int (activation )
1431
- print ("activation_enum_int" , activation , activation_enum_int )
1431
+ # print("activation_enum_int", activation, activation_enum_int)
1432
1432
1433
1433
gdpa_kernel_tma_ws_blackwell [grid_tma_persistent ](
1434
1434
q ,
0 commit comments