Skip to content

Commit 8f7b5a2

Browse files
authored
Fix data type: derive the buffer dtype from the V input instead of hardcoding float16
Differential Revision: D81138002
Pull Request resolved: #361
1 parent febe334 commit 8f7b5a2

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -254,12 +254,17 @@ def gdpa_kernel_tma_ws_blackwell(
254254
block_shape=[BLOCK_N, BLOCK_D],
255255
)
256256

257+
if USE_ON_DEVICE_TMA:
258+
dtype = V.dtype.element_ty # v_dtype)
259+
else:
260+
dtype = tlx.dtype_of(v_desc)
261+
257262
# allocate buffers for q0, q1
258-
q0_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), tl.float16, 1)
259-
q1_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), tl.float16, 1)
263+
q0_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), dtype, 1)
264+
q1_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), dtype, 1)
260265

261266
# allocate buffers for k, v
262-
kv_buf = tlx.local_alloc((BLOCK_N, BLOCK_D), tl.float16, NUM_BUFFERS_KV) # k
267+
kv_buf = tlx.local_alloc((BLOCK_N, BLOCK_D), dtype, NUM_BUFFERS_KV) # k
263268

264269
# allocate tmem for outputs of 4 dots (after partitioning)
265270
# qk0 = q0 dot k, qk1 = q1 dot k, acc0 = p0 dot v, acc1 = p1 dot v
@@ -357,7 +362,7 @@ def gdpa_kernel_tma_ws_blackwell(
357362
else:
358363
p0 = p0.to(tlx.dtype_of(v_desc))
359364
qk_view = tlx.local_view(qk0_buf, bufIdx)
360-
p0_view = tlx.local_reinterpret(qk_view, tl.float16)
365+
p0_view = tlx.local_reinterpret(qk_view, dtype)
361366
tlx.local_store(p0_view, p0) # , tlx.storage_kind.tmem)
362367
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
363368
consumer_release_qk_view = tlx.local_view(producer_qk0, bufIdx)
@@ -448,7 +453,7 @@ def gdpa_kernel_tma_ws_blackwell(
448453
else:
449454
p1 = p1.to(tlx.dtype_of(v_desc))
450455
qk_view = tlx.local_view(qk1_buf, bufIdx)
451-
p1_view = tlx.local_reinterpret(qk_view, tl.float16)
456+
p1_view = tlx.local_reinterpret(qk_view, dtype)
452457
tlx.local_store(p1_view, p1) # , tlx.storage_kind.tmem)
453458
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
454459
consumer_release_qk_view = tlx.local_view(producer_qk1, bufIdx)
@@ -617,7 +622,7 @@ def gdpa_kernel_tma_ws_blackwell(
617622
# reinterpret qk0 as p0
618623
# p0_view = _reinterpret(qk0_buf, bufIdx_p)
619624
qk_view = tlx.local_view(qk0_buf, bufIdx_p)
620-
p0_view = tlx.local_reinterpret(qk_view, tl.float16)
625+
p0_view = tlx.local_reinterpret(qk_view, dtype)
621626

622627
bufIdx_o, phase_o = _get_bufidx_phase(accum_cnt_o, NUM_BUFFERS_O)
623628
producer_commit_o0_view = tlx.local_view(
@@ -709,7 +714,7 @@ def gdpa_kernel_tma_ws_blackwell(
709714
# reinterpret as p1
710715
# p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
711716
qk_view = tlx.local_view(qk1_buf, bufIdx_qk1)
712-
p1_view = tlx.local_reinterpret(qk_view, tl.float16)
717+
p1_view = tlx.local_reinterpret(qk_view, dtype)
713718
tlx.async_dot( # p1 . v from previous iteration
714719
p1_view,
715720
v_view,
@@ -770,7 +775,7 @@ def gdpa_kernel_tma_ws_blackwell(
770775
# reinterpret as p0
771776
# p0_view = _reinterpret(qk0_buf, bufIdx_qk)
772777
qk_view = tlx.local_view(qk0_buf, bufIdx_qk)
773-
p0_view = tlx.local_reinterpret(qk_view, tl.float16)
778+
p0_view = tlx.local_reinterpret(qk_view, dtype)
774779

775780
v_view = tlx.local_view(kv_buf, bufIdx_v)
776781
bufIdx_o, phase_o = _get_bufidx_phase(
@@ -819,7 +824,7 @@ def gdpa_kernel_tma_ws_blackwell(
819824
) # consumer wait for p1 due to reuse of p1 and qk1
820825
# p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
821826
qk_view = tlx.local_view(qk1_buf, bufIdx_qk1)
822-
p1_view = tlx.local_reinterpret(qk_view, tl.float16)
827+
p1_view = tlx.local_reinterpret(qk_view, dtype)
823828

824829
accum_cnt_qk1 += 1
825830
# release p0, p1 via producer_commit_qk0, qk1 barriers

0 commit comments

Comments
 (0)