@@ -254,12 +254,17 @@ def gdpa_kernel_tma_ws_blackwell(
254
254
block_shape = [BLOCK_N , BLOCK_D ],
255
255
)
256
256
257
+ if USE_ON_DEVICE_TMA :
258
+ dtype = V .dtype .element_ty # element type of V (v_dtype)
259
+ else :
260
+ dtype = tlx .dtype_of (v_desc )
261
+
257
262
# allocate buffers for q0, q1
258
- q0_buf = tlx .local_alloc ((BLOCK_M // 2 , BLOCK_D ), tl . float16 , 1 )
259
- q1_buf = tlx .local_alloc ((BLOCK_M // 2 , BLOCK_D ), tl . float16 , 1 )
263
+ q0_buf = tlx .local_alloc ((BLOCK_M // 2 , BLOCK_D ), dtype , 1 )
264
+ q1_buf = tlx .local_alloc ((BLOCK_M // 2 , BLOCK_D ), dtype , 1 )
260
265
261
266
# allocate buffers for k, v
262
- kv_buf = tlx .local_alloc ((BLOCK_N , BLOCK_D ), tl . float16 , NUM_BUFFERS_KV ) # k
267
+ kv_buf = tlx .local_alloc ((BLOCK_N , BLOCK_D ), dtype , NUM_BUFFERS_KV ) # shared by k and v
263
268
264
269
# allocate tmem for outputs of 4 dots (after partitioning)
265
270
# qk0 = q0 dot k, qk1 = q1 dot k, acc0 = p0 dot v, acc1 = p1 dot v
@@ -357,7 +362,7 @@ def gdpa_kernel_tma_ws_blackwell(
357
362
else :
358
363
p0 = p0 .to (tlx .dtype_of (v_desc ))
359
364
qk_view = tlx .local_view (qk0_buf , bufIdx )
360
- p0_view = tlx .local_reinterpret (qk_view , tl . float16 )
365
+ p0_view = tlx .local_reinterpret (qk_view , dtype )
361
366
tlx .local_store (p0_view , p0 ) # , tlx.storage_kind.tmem)
362
367
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
363
368
consumer_release_qk_view = tlx .local_view (producer_qk0 , bufIdx )
@@ -448,7 +453,7 @@ def gdpa_kernel_tma_ws_blackwell(
448
453
else :
449
454
p1 = p1 .to (tlx .dtype_of (v_desc ))
450
455
qk_view = tlx .local_view (qk1_buf , bufIdx )
451
- p1_view = tlx .local_reinterpret (qk_view , tl . float16 )
456
+ p1_view = tlx .local_reinterpret (qk_view , dtype )
452
457
tlx .local_store (p1_view , p1 ) # , tlx.storage_kind.tmem)
453
458
# p and qk reuse tmem space, single producer commit for p via consumer_release_qk
454
459
consumer_release_qk_view = tlx .local_view (producer_qk1 , bufIdx )
@@ -617,7 +622,7 @@ def gdpa_kernel_tma_ws_blackwell(
617
622
# reinterpret qk0 as p0
618
623
# p0_view = _reinterpret(qk0_buf, bufIdx_p)
619
624
qk_view = tlx .local_view (qk0_buf , bufIdx_p )
620
- p0_view = tlx .local_reinterpret (qk_view , tl . float16 )
625
+ p0_view = tlx .local_reinterpret (qk_view , dtype )
621
626
622
627
bufIdx_o , phase_o = _get_bufidx_phase (accum_cnt_o , NUM_BUFFERS_O )
623
628
producer_commit_o0_view = tlx .local_view (
@@ -709,7 +714,7 @@ def gdpa_kernel_tma_ws_blackwell(
709
714
# reinterpret as p1
710
715
# p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
711
716
qk_view = tlx .local_view (qk1_buf , bufIdx_qk1 )
712
- p1_view = tlx .local_reinterpret (qk_view , tl . float16 )
717
+ p1_view = tlx .local_reinterpret (qk_view , dtype )
713
718
tlx .async_dot ( # p1 . v from previous iteration
714
719
p1_view ,
715
720
v_view ,
@@ -770,7 +775,7 @@ def gdpa_kernel_tma_ws_blackwell(
770
775
# reinterpret as p0
771
776
# p0_view = _reinterpret(qk0_buf, bufIdx_qk)
772
777
qk_view = tlx .local_view (qk0_buf , bufIdx_qk )
773
- p0_view = tlx .local_reinterpret (qk_view , tl . float16 )
778
+ p0_view = tlx .local_reinterpret (qk_view , dtype )
774
779
775
780
v_view = tlx .local_view (kv_buf , bufIdx_v )
776
781
bufIdx_o , phase_o = _get_bufidx_phase (
@@ -819,7 +824,7 @@ def gdpa_kernel_tma_ws_blackwell(
819
824
) # consumer wait for p1 due to reuse of p1 and qk1
820
825
# p1_view = _reinterpret(qk1_buf, bufIdx_qk1)
821
826
qk_view = tlx .local_view (qk1_buf , bufIdx_qk1 )
822
- p1_view = tlx .local_reinterpret (qk_view , tl . float16 )
827
+ p1_view = tlx .local_reinterpret (qk_view , dtype )
823
828
824
829
accum_cnt_qk1 += 1
825
830
# release p0, p1 via producer_commit_qk0, qk1 barriers
0 commit comments