Skip to content

Commit c04d334

Browse files
committed
remove debugging, correct grid
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 9d2dfe5 commit c04d334

File tree

1 file changed

+38
-38
lines changed

1 file changed

+38
-38
lines changed

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 38 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -584,10 +584,10 @@ def gdpa_kernel_tma_ws_blackwell(
584584
out_offset = off_h.to(tl.int64) * stride_oh
585585
if start_m * BLOCK_M < qlen:
586586
lo, hi = 0, klen
587-
tl.device_print("default", hi)
587+
# tl.device_print("default", hi)
588588
for start_n in range(lo, hi, BLOCK_N):
589589
start_n = tl.multiple_of(start_n, BLOCK_N)
590-
tl.device_print("default start_n", start_n)
590+
# tl.device_print("default start_n", start_n)
591591
## communication channel for qk0, p0
592592
# _do_activation(
593593
# qk0_buf,
@@ -604,8 +604,8 @@ def gdpa_kernel_tma_ws_blackwell(
604604
phase = (accum_cnt // NUM_BUFFERS_QK) & 1
605605
qk_view = tlx.local_view(qk0_buf, bufIdx)
606606
consumer_qk_view = tlx.local_view(producer_commit_qk0, bufIdx)
607-
tl.device_print("producer_commit_qk0", accum_cnt)
608-
tl.device_print("producer_commit_qk0_phase", phase)
607+
# tl.device_print("producer_commit_qk0", accum_cnt)
608+
# tl.device_print("producer_commit_qk0_phase", phase)
609609
tlx.barrier_wait(consumer_qk_view, phase)
610610
qk0 = tlx.local_load(qk_view) # , tlx.storage_kind.tmem)
611611
# ConsumerWait for qk, ProducerAcquire for p
@@ -636,8 +636,8 @@ def gdpa_kernel_tma_ws_blackwell(
636636
phase = (accum_cnt // NUM_BUFFERS_O) & 1
637637
# consumer wait of o0: producer_commit
638638
consumer_o0_view = tlx.local_view(producer_commit_o0, bufIdx)
639-
tl.device_print("producer_commit_o0", accum_cnt)
640-
tl.device_print("producer_commit_o0_phase", phase)
639+
# tl.device_print("producer_commit_o0", accum_cnt)
640+
# tl.device_print("producer_commit_o0_phase", phase)
641641
tlx.barrier_wait(consumer_o0_view, phase)
642642
accum_cnt += 1
643643

@@ -653,7 +653,7 @@ def gdpa_kernel_tma_ws_blackwell(
653653
consumer_release_o0_view = tlx.local_view(
654654
producer_o0, bufIdx_o_outer
655655
)
656-
tl.device_print("arrive producer_o0", accum_cnt_outer)
656+
# tl.device_print("arrive producer_o0", accum_cnt_outer)
657657
tlx.barrier_arrive(consumer_release_o0_view, 1)
658658
o0_desc = tl.make_tensor_descriptor(
659659
Out,
@@ -838,11 +838,11 @@ def gdpa_kernel_tma_ws_blackwell(
838838
consumer_q0_view = tlx.local_view(consumer_q0, bufIdx_q)
839839
consumer_k_view = tlx.local_view(consumer_k, bufIdx_k)
840840
# producer_qk0_view = tlx.local_view(producer_qk0, bufIdx_qk)
841-
tl.device_print("consumer_q0_prologue", accum_cnt_q)
842-
tl.device_print("consumer_q0_phase", phase_q)
841+
# tl.device_print("consumer_q0_prologue", accum_cnt_q)
842+
# tl.device_print("consumer_q0_phase", phase_q)
843843
tlx.barrier_wait(consumer_q0_view, phase_q) # consumer wait for q0
844-
tl.device_print("consumer_k", accum_cnt_k)
845-
tl.device_print("consumer_k_phase", phase_k)
844+
# tl.device_print("consumer_k", accum_cnt_k)
845+
# tl.device_print("consumer_k_phase", phase_k)
846846
tlx.barrier_wait(consumer_k_view, phase_k) # consumer wait for k
847847
# Do we need the initial acquire here?
848848
# dot partition has producer commit for qk0, activation partition consumer wait for qk0
@@ -866,8 +866,8 @@ def gdpa_kernel_tma_ws_blackwell(
866866

867867
consumer_q1_view = tlx.local_view(consumer_q1, bufIdx_q)
868868
# producer_qk1_view = tlx.local_view(producer_qk1, bufIdx_qk)
869-
tl.device_print("consumer_q1", accum_cnt_q)
870-
tl.device_print("consumer_q1_phase", phase_q)
869+
# tl.device_print("consumer_q1", accum_cnt_q)
870+
# tl.device_print("consumer_q1_phase", phase_q)
871871
tlx.barrier_wait(consumer_q1_view, phase_q) # consumer wait for q1
872872
# tlx.barrier_wait(producer_qk1_view, phase_qk) # producer acquire for qk1
873873
# consumer release for k, producer commit for qk1
@@ -889,17 +889,17 @@ def gdpa_kernel_tma_ws_blackwell(
889889
# accum_cnt_qk1 += 1
890890

891891
consumer_v_view = tlx.local_view(consumer_v, bufIdx_k)
892-
tl.device_print("consumer_v", accum_cnt_k)
893-
tl.device_print("consumer_v_phase", phase_k)
892+
# tl.device_print("consumer_v", accum_cnt_k)
893+
# tl.device_print("consumer_v_phase", phase_k)
894894
tlx.barrier_wait(consumer_v_view, phase_k) # consumer wait for v
895895
# need to acquire o0 to make sure epilogue is done, this is needed for each outer loop
896896
bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
897897
accum_cnt_outer, NUM_BUFFERS_O
898898
)
899899
producer_o0_view = tlx.local_view(producer_o0, bufIdx_o_outer)
900900
producer_o1_view = tlx.local_view(producer_o1, bufIdx_o_outer)
901-
tl.device_print("producer_o0", accum_cnt_outer)
902-
tl.device_print("producer_o0_phase", phase_o_outer)
901+
# tl.device_print("producer_o0", accum_cnt_outer)
902+
# tl.device_print("producer_o0_phase", phase_o_outer)
903903
tlx.barrier_wait(
904904
producer_o0_view, phase_o_outer ^ 1
905905
) # producer acquire for o0
@@ -908,8 +908,8 @@ def gdpa_kernel_tma_ws_blackwell(
908908
# dot partition: producer commit of qk0, ..., consumer wait for p0 (use the same barrier as producer_qk0)
909909
bufIdx_p, phase_p = _get_bufidx_phase(accum_cnt_qk, NUM_BUFFERS_QK)
910910
consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_p)
911-
tl.device_print("producer_qk0", accum_cnt_qk)
912-
tl.device_print("producer_qk0_phase", phase_p)
911+
# tl.device_print("producer_qk0", accum_cnt_qk)
912+
# tl.device_print("producer_qk0_phase", phase_p)
913913
tlx.barrier_wait(
914914
consumer_p0_view, phase_p
915915
) # consumer wait for p0 due to reuse of p0 and qk0
@@ -938,11 +938,11 @@ def gdpa_kernel_tma_ws_blackwell(
938938
mma_iters = (hi - lo) // BLOCK_N
939939
accum_cnt_k += 1
940940
accum_cnt_qk += 1
941-
tl.device_print("gemm for ", hi)
942-
tl.device_print("gemm mma_iters ", mma_iters)
941+
# tl.device_print("gemm for ", hi)
942+
# tl.device_print("gemm mma_iters ", mma_iters)
943943
for it in range(BLOCK_N, hi, BLOCK_N):
944944
# for it in range(mma_iters - 1):
945-
tl.device_print("gemm iter ", it)
945+
# tl.device_print("gemm iter ", it)
946946
bufIdx_k, phase_k = _get_bufidx_phase(
947947
accum_cnt_k, NUM_BUFFERS_K
948948
)
@@ -952,8 +952,8 @@ def gdpa_kernel_tma_ws_blackwell(
952952

953953
# q0 dot k
954954
consumer_k_view = tlx.local_view(consumer_k, bufIdx_k)
955-
tl.device_print("consumer_k", accum_cnt_k)
956-
tl.device_print("consumer_k_phase", phase_k)
955+
# tl.device_print("consumer_k", accum_cnt_k)
956+
# tl.device_print("consumer_k_phase", phase_k)
957957
tlx.barrier_wait(
958958
consumer_k_view, phase_k
959959
) # consumer wait for k
@@ -975,13 +975,13 @@ def gdpa_kernel_tma_ws_blackwell(
975975
accum_cnt_qk1, NUM_BUFFERS_QK
976976
)
977977
consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
978-
tl.device_print("producer_o1", accum_cnt_outer)
979-
tl.device_print("producer_o1_phase", phase_o_outer)
978+
# tl.device_print("producer_o1", accum_cnt_outer)
979+
# tl.device_print("producer_o1_phase", phase_o_outer)
980980
tlx.barrier_wait(
981981
producer_o1_view, phase_o_outer ^ 1, first
982982
) # producer acquire for o1, only needed for first iteration
983-
tl.device_print("producer_qk1", accum_cnt_qk1)
984-
tl.device_print("producer_qk1_phase", phase_qk1)
983+
# tl.device_print("producer_qk1", accum_cnt_qk1)
984+
# tl.device_print("producer_qk1_phase", phase_qk1)
985985
tlx.barrier_wait(
986986
consumer_p1_view, phase_qk1
987987
) # consumer wait for p1 use producer_qk1 due to reuse
@@ -1038,16 +1038,16 @@ def gdpa_kernel_tma_ws_blackwell(
10381038

10391039
# p0 dot v
10401040
consumer_v_view = tlx.local_view(consumer_v, bufIdx_k)
1041-
tl.device_print("consumer_v", accum_cnt_k)
1042-
tl.device_print("consumer_v_phase", phase_k)
1041+
# tl.device_print("consumer_v", accum_cnt_k)
1042+
# tl.device_print("consumer_v_phase", phase_k)
10431043
tlx.barrier_wait(
10441044
consumer_v_view, phase_k
10451045
) # consumer wait for v
10461046
# no need to acquire o0 as this is the only partition updating it
10471047
# tlx.barrier_wait(producer_o0) # producer acquire for o0
10481048
consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_qk)
1049-
tl.device_print("producer_qk0", accum_cnt_qk)
1050-
tl.device_print("producer_qk0_phase", phase_qk)
1049+
# tl.device_print("producer_qk0", accum_cnt_qk)
1050+
# tl.device_print("producer_qk0_phase", phase_qk)
10511051
tlx.barrier_wait(
10521052
consumer_p0_view, phase_qk
10531053
) # consumer wait for p0 use producer_qk0 due to reuse
@@ -1084,17 +1084,17 @@ def gdpa_kernel_tma_ws_blackwell(
10841084
tlx.tcgen05_commit(release_q0_view)
10851085
release_q1_view = tlx.local_view(consumer_release_q1, bufIdx_q)
10861086
tlx.tcgen05_commit(release_q1_view)
1087-
tl.device_print("producer_o1_epilogue", accum_cnt_outer)
1088-
tl.device_print("producer_o1_phase", phase_o_outer)
1087+
# tl.device_print("producer_o1_epilogue", accum_cnt_outer)
1088+
# tl.device_print("producer_o1_phase", phase_o_outer)
10891089
tlx.barrier_wait(
10901090
producer_o1_view, phase_o_outer ^ 1, first
10911091
) # producer acquire for o1 at the first iteration
10921092
bufIdx_qk1, phase_qk1 = _get_bufidx_phase(
10931093
accum_cnt_qk1, NUM_BUFFERS_QK
10941094
)
10951095
consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
1096-
tl.device_print("producer_qk1_epilogue", accum_cnt_qk1)
1097-
tl.device_print("producer_qk1_phase", phase_qk1)
1096+
# tl.device_print("producer_qk1_epilogue", accum_cnt_qk1)
1097+
# tl.device_print("producer_qk1_phase", phase_qk1)
10981098
tlx.barrier_wait(
10991099
consumer_p1_view, phase_qk1
11001100
) # consumer wait for p1 due to reuse of p1 and qk1
@@ -1416,7 +1416,7 @@ def alloc_fn(size: int, alignment: int, _):
14161416

14171417
def grid_tma_persistent(META):
14181418
return (
1419-
1, # min(NUM_SMS, triton.cdiv(max_seq_len_q, META["BLOCK_M"]) * BATCH * nheads),
1419+
min(NUM_SMS, triton.cdiv(max_seq_len_q, META["BLOCK_M"]) * BATCH * nheads),
14201420
1,
14211421
1,
14221422
)
@@ -1428,7 +1428,7 @@ def grid_tma_persistent(META):
14281428
vstrides = v.stride()
14291429

14301430
activation_enum_int = activation_string_to_int(activation)
1431-
print("activation_enum_int", activation, activation_enum_int)
1431+
# print("activation_enum_int", activation, activation_enum_int)
14321432

14331433
gdpa_kernel_tma_ws_blackwell[grid_tma_persistent](
14341434
q,

0 commit comments

Comments
 (0)