Skip to content

Commit 9d2dfe5

Browse files
committed
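Fix issues (TMA expected-byte count for q0/q1 loads, GEMM inner-loop bound) and add debugging prints
Summary: Launch with a grid of 1 program (original grid size commented out) for debugging. Test Plan: Reviewers: Subscribers: Tasks: Tags: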
1 parent 57e9261 commit 9d2dfe5

File tree

1 file changed

+46
-5
lines changed

1 file changed

+46
-5
lines changed

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,10 @@ def gdpa_kernel_tma_ws_blackwell(
584584
out_offset = off_h.to(tl.int64) * stride_oh
585585
if start_m * BLOCK_M < qlen:
586586
lo, hi = 0, klen
587+
tl.device_print("default", hi)
587588
for start_n in range(lo, hi, BLOCK_N):
588589
start_n = tl.multiple_of(start_n, BLOCK_N)
590+
tl.device_print("default start_n", start_n)
589591
## communication channel for qk0, p0
590592
# _do_activation(
591593
# qk0_buf,
@@ -602,6 +604,8 @@ def gdpa_kernel_tma_ws_blackwell(
602604
phase = (accum_cnt // NUM_BUFFERS_QK) & 1
603605
qk_view = tlx.local_view(qk0_buf, bufIdx)
604606
consumer_qk_view = tlx.local_view(producer_commit_qk0, bufIdx)
607+
tl.device_print("producer_commit_qk0", accum_cnt)
608+
tl.device_print("producer_commit_qk0_phase", phase)
605609
tlx.barrier_wait(consumer_qk_view, phase)
606610
qk0 = tlx.local_load(qk_view) # , tlx.storage_kind.tmem)
607611
# ConsumerWait for qk, ProducerAcquire for p
@@ -632,6 +636,8 @@ def gdpa_kernel_tma_ws_blackwell(
632636
phase = (accum_cnt // NUM_BUFFERS_O) & 1
633637
# consumer wait of o0: producer_commit
634638
consumer_o0_view = tlx.local_view(producer_commit_o0, bufIdx)
639+
tl.device_print("producer_commit_o0", accum_cnt)
640+
tl.device_print("producer_commit_o0_phase", phase)
635641
tlx.barrier_wait(consumer_o0_view, phase)
636642
accum_cnt += 1
637643

@@ -647,6 +653,7 @@ def gdpa_kernel_tma_ws_blackwell(
647653
consumer_release_o0_view = tlx.local_view(
648654
producer_o0, bufIdx_o_outer
649655
)
656+
tl.device_print("arrive producer_o0", accum_cnt_outer)
650657
tlx.barrier_arrive(consumer_release_o0_view, 1)
651658
o0_desc = tl.make_tensor_descriptor(
652659
Out,
@@ -662,6 +669,7 @@ def gdpa_kernel_tma_ws_blackwell(
662669
o0.to(Out.type.element_ty),
663670
)
664671
accum_cnt_outer += 1
672+
tile_idx += num_progs
665673

666674
with tlx.async_task(num_warps=4):
667675
accum_cnt = 0
@@ -761,6 +769,7 @@ def gdpa_kernel_tma_ws_blackwell(
761769
o1.to(Out.type.element_ty),
762770
)
763771
accum_cnt_outer += 1
772+
tile_idx += num_progs
764773

765774
with tlx.async_task(num_warps=1): # gemm
766775
accum_cnt_q = 0
@@ -829,7 +838,11 @@ def gdpa_kernel_tma_ws_blackwell(
829838
consumer_q0_view = tlx.local_view(consumer_q0, bufIdx_q)
830839
consumer_k_view = tlx.local_view(consumer_k, bufIdx_k)
831840
# producer_qk0_view = tlx.local_view(producer_qk0, bufIdx_qk)
841+
tl.device_print("consumer_q0_prologue", accum_cnt_q)
842+
tl.device_print("consumer_q0_phase", phase_q)
832843
tlx.barrier_wait(consumer_q0_view, phase_q) # consumer wait for q0
844+
tl.device_print("consumer_k", accum_cnt_k)
845+
tl.device_print("consumer_k_phase", phase_k)
833846
tlx.barrier_wait(consumer_k_view, phase_k) # consumer wait for k
834847
# Do we need the initial acquire here?
835848
# dot partition has producer commit for qk0, activation partition consumer wait for qk0
@@ -853,6 +866,8 @@ def gdpa_kernel_tma_ws_blackwell(
853866

854867
consumer_q1_view = tlx.local_view(consumer_q1, bufIdx_q)
855868
# producer_qk1_view = tlx.local_view(producer_qk1, bufIdx_qk)
869+
tl.device_print("consumer_q1", accum_cnt_q)
870+
tl.device_print("consumer_q1_phase", phase_q)
856871
tlx.barrier_wait(consumer_q1_view, phase_q) # consumer wait for q1
857872
# tlx.barrier_wait(producer_qk1_view, phase_qk) # producer acquire for qk1
858873
# consumer release for k, producer commit for qk1
@@ -874,13 +889,17 @@ def gdpa_kernel_tma_ws_blackwell(
874889
# accum_cnt_qk1 += 1
875890

876891
consumer_v_view = tlx.local_view(consumer_v, bufIdx_k)
892+
tl.device_print("consumer_v", accum_cnt_k)
893+
tl.device_print("consumer_v_phase", phase_k)
877894
tlx.barrier_wait(consumer_v_view, phase_k) # consumer wait for v
878895
# need to acquire o0 to make sure epilogue is done, this is needed for each outer loop
879896
bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
880897
accum_cnt_outer, NUM_BUFFERS_O
881898
)
882899
producer_o0_view = tlx.local_view(producer_o0, bufIdx_o_outer)
883900
producer_o1_view = tlx.local_view(producer_o1, bufIdx_o_outer)
901+
tl.device_print("producer_o0", accum_cnt_outer)
902+
tl.device_print("producer_o0_phase", phase_o_outer)
884903
tlx.barrier_wait(
885904
producer_o0_view, phase_o_outer ^ 1
886905
) # producer acquire for o0
@@ -889,6 +908,8 @@ def gdpa_kernel_tma_ws_blackwell(
889908
# dot partition: producer commit of qk0, ..., consumer wait for p0 (use the same barrier as producer_qk0)
890909
bufIdx_p, phase_p = _get_bufidx_phase(accum_cnt_qk, NUM_BUFFERS_QK)
891910
consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_p)
911+
tl.device_print("producer_qk0", accum_cnt_qk)
912+
tl.device_print("producer_qk0_phase", phase_p)
892913
tlx.barrier_wait(
893914
consumer_p0_view, phase_p
894915
) # consumer wait for p0 due to reuse of p0 and qk0
@@ -917,7 +938,11 @@ def gdpa_kernel_tma_ws_blackwell(
917938
mma_iters = (hi - lo) // BLOCK_N
918939
accum_cnt_k += 1
919940
accum_cnt_qk += 1
920-
for _ in range(mma_iters - 1):
941+
tl.device_print("gemm for ", hi)
942+
tl.device_print("gemm mma_iters ", mma_iters)
943+
for it in range(BLOCK_N, hi, BLOCK_N):
944+
# for it in range(mma_iters - 1):
945+
tl.device_print("gemm iter ", it)
921946
bufIdx_k, phase_k = _get_bufidx_phase(
922947
accum_cnt_k, NUM_BUFFERS_K
923948
)
@@ -927,6 +952,8 @@ def gdpa_kernel_tma_ws_blackwell(
927952

928953
# q0 dot k
929954
consumer_k_view = tlx.local_view(consumer_k, bufIdx_k)
955+
tl.device_print("consumer_k", accum_cnt_k)
956+
tl.device_print("consumer_k_phase", phase_k)
930957
tlx.barrier_wait(
931958
consumer_k_view, phase_k
932959
) # consumer wait for k
@@ -948,9 +975,13 @@ def gdpa_kernel_tma_ws_blackwell(
948975
accum_cnt_qk1, NUM_BUFFERS_QK
949976
)
950977
consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
978+
tl.device_print("producer_o1", accum_cnt_outer)
979+
tl.device_print("producer_o1_phase", phase_o_outer)
951980
tlx.barrier_wait(
952981
producer_o1_view, phase_o_outer ^ 1, first
953982
) # producer acquire for o1, only needed for first iteration
983+
tl.device_print("producer_qk1", accum_cnt_qk1)
984+
tl.device_print("producer_qk1_phase", phase_qk1)
954985
tlx.barrier_wait(
955986
consumer_p1_view, phase_qk1
956987
) # consumer wait for p1 use producer_qk1 due to reuse
@@ -1007,12 +1038,16 @@ def gdpa_kernel_tma_ws_blackwell(
10071038

10081039
# p0 dot v
10091040
consumer_v_view = tlx.local_view(consumer_v, bufIdx_k)
1041+
tl.device_print("consumer_v", accum_cnt_k)
1042+
tl.device_print("consumer_v_phase", phase_k)
10101043
tlx.barrier_wait(
10111044
consumer_v_view, phase_k
10121045
) # consumer wait for v
10131046
# no need to acquire o0 as this is the only partition updating it
10141047
# tlx.barrier_wait(producer_o0) # producer acquire for o0
10151048
consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_qk)
1049+
tl.device_print("producer_qk0", accum_cnt_qk)
1050+
tl.device_print("producer_qk0_phase", phase_qk)
10161051
tlx.barrier_wait(
10171052
consumer_p0_view, phase_qk
10181053
) # consumer wait for p0 use producer_qk0 due to reuse
@@ -1049,13 +1084,17 @@ def gdpa_kernel_tma_ws_blackwell(
10491084
tlx.tcgen05_commit(release_q0_view)
10501085
release_q1_view = tlx.local_view(consumer_release_q1, bufIdx_q)
10511086
tlx.tcgen05_commit(release_q1_view)
1087+
tl.device_print("producer_o1_epilogue", accum_cnt_outer)
1088+
tl.device_print("producer_o1_phase", phase_o_outer)
10521089
tlx.barrier_wait(
10531090
producer_o1_view, phase_o_outer ^ 1, first
10541091
) # producer acquire for o1 at the first iteration
10551092
bufIdx_qk1, phase_qk1 = _get_bufidx_phase(
10561093
accum_cnt_qk1, NUM_BUFFERS_QK
10571094
)
10581095
consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
1096+
tl.device_print("producer_qk1_epilogue", accum_cnt_qk1)
1097+
tl.device_print("producer_qk1_phase", phase_qk1)
10591098
tlx.barrier_wait(
10601099
consumer_p1_view, phase_qk1
10611100
) # consumer wait for p1 due to reuse of p1 and qk1
@@ -1092,6 +1131,7 @@ def gdpa_kernel_tma_ws_blackwell(
10921131
accum_cnt_outer += 1
10931132
# signal producer commit of epi0 and epi1, we don't want to block the gemm partition
10941133
# to wait for the completion
1134+
tile_idx += num_progs
10951135

10961136
with tlx.async_task(num_warps=1): # load
10971137
accum_count_q = 0
@@ -1153,7 +1193,7 @@ def gdpa_kernel_tma_ws_blackwell(
11531193
consumer_q0, q_bufIdx
11541194
) # full_bars, bufIdx)
11551195
tlx.barrier_expect_bytes(
1156-
q0_full_view, BLOCK_M * BLOCK_D * 2
1196+
q0_full_view, BLOCK_M // 2 * BLOCK_D * 2
11571197
) # num_bytes)
11581198
q0_smem_view = tlx.local_view(q0_buf, q_bufIdx)
11591199
tlx.async_descriptor_load(
@@ -1183,7 +1223,7 @@ def gdpa_kernel_tma_ws_blackwell(
11831223
# barrier for producer commit
11841224
q1_full_view = tlx.local_view(consumer_q1, q_bufIdx)
11851225
tlx.barrier_expect_bytes(
1186-
q1_full_view, BLOCK_M * BLOCK_D * 2
1226+
q1_full_view, BLOCK_M // 2 * BLOCK_D * 2
11871227
) # num_bytes)
11881228
q1_smem_view = tlx.local_view(q1_buf, q_bufIdx)
11891229
tlx.async_descriptor_load(
@@ -1259,8 +1299,9 @@ def gdpa_kernel_tma_ws_blackwell(
12591299
v_full_view,
12601300
)
12611301
accum_count_k += 1
1262-
1302+
# outside of inner for
12631303
accum_count_q += 1
1304+
tile_idx += num_progs
12641305

12651306

12661307
def next_power_of_2(x):
@@ -1375,7 +1416,7 @@ def alloc_fn(size: int, alignment: int, _):
13751416

13761417
def grid_tma_persistent(META):
13771418
return (
1378-
min(NUM_SMS, triton.cdiv(max_seq_len_q, META["BLOCK_M"]) * BATCH * nheads),
1419+
1, # min(NUM_SMS, triton.cdiv(max_seq_len_q, META["BLOCK_M"]) * BATCH * nheads),
13791420
1,
13801421
1,
13811422
)

0 commit comments

Comments
 (0)