Skip to content

Commit e6b4ddc

Browse files
committed
on-host TMA, needs to add masking
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 55075f0 commit e6b4ddc

File tree

1 file changed

+128
-70
lines changed

1 file changed

+128
-70
lines changed

tritonbench/operators/gdpa/gdpa_blackwell_tlx.py

Lines changed: 128 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,27 @@
66
import triton
77
import triton.language as tl
88
import triton.language.extra.tlx as tlx
9+
from triton.tools.tensor_descriptor import TensorDescriptor
910

1011
from .gdpa_utils import get_num_sms
1112
from .math import activation_string_to_int
1213

1314

15+
def _host_descriptor_pre_hook(nargs):
    """Autotune pre-hook that sizes host-side TMA descriptors for a config.

    ``nargs`` is the name -> value mapping handed to the hook. When
    ``nargs["Q"]`` is not a ``TensorDescriptor`` (the on-device TMA path)
    nothing is modified. Otherwise the ``block_shape`` attribute of the
    ``Q``, ``V``, ``K`` and ``Out`` entries is overwritten in place using
    the ``BLOCK_M``, ``BLOCK_N`` and ``BLOCK_D`` values found in ``nargs``.
    """
    # The block sizes are looked up before the descriptor check, so a
    # missing key raises KeyError on both paths.
    block_m, block_n, block_d = [
        nargs[key] for key in ("BLOCK_M", "BLOCK_N", "BLOCK_D")
    ]

    if isinstance(nargs["Q"], TensorDescriptor):
        # Q and Out tiles cover BLOCK_M divided across the MMA groups;
        # K and V tiles use BLOCK_N rows.
        mma_groups = 2
        rows_per_group = block_m // mma_groups
        tile_rows = (
            ("Q", rows_per_group),
            ("V", block_n),
            ("K", block_n),
            ("Out", rows_per_group),
        )
        for arg_name, rows in tile_rows:
            # A fresh list per descriptor so no two share the same object.
            nargs[arg_name].block_shape = [rows, block_d]
28+
29+
1430
def get_cuda_autotune_config():
1531
return [
1632
triton.Config(
@@ -24,6 +40,7 @@ def get_cuda_autotune_config():
2440
},
2541
num_warps=4,
2642
num_stages=1,
43+
pre_hook=_host_descriptor_pre_hook,
2744
)
2845
for BM in [256] # 128 or 256
2946
for BN in [128]
@@ -198,6 +215,7 @@ def gdpa_kernel_tma_ws_blackwell(
198215
BROADCAST_Q: tl.constexpr,
199216
IS_DENSE_KV: tl.constexpr,
200217
activation_enum_int: tl.constexpr,
218+
USE_ON_DEVICE_TMA: tl.constexpr,
201219
NUM_BUFFERS_Q: tl.constexpr,
202220
NUM_BUFFERS_KV: tl.constexpr,
203221
NUM_BUFFERS_QK: tl.constexpr,
@@ -214,21 +232,27 @@ def gdpa_kernel_tma_ws_blackwell(
214232
tiles_per_sm += 1
215233

216234
tile_idx = prog_id
235+
if not USE_ON_DEVICE_TMA:
236+
q_desc = Q
237+
k_desc = K
238+
v_desc = V
239+
o_desc = Out
217240

218241
# start with on-device TMA where descriptors for k, v are set up outside of the persistent
219242
# loop and descriptor for q is set up inside the persistent loop.
220-
k_desc = tl.make_tensor_descriptor(
221-
K,
222-
shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
223-
strides=[HEAD_DIM * H // G, 1],
224-
block_shape=[BLOCK_N, BLOCK_D],
225-
)
226-
v_desc = tl.make_tensor_descriptor(
227-
V,
228-
shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
229-
strides=[HEAD_DIM * H // G, 1],
230-
block_shape=[BLOCK_N, BLOCK_D],
231-
)
243+
if USE_ON_DEVICE_TMA:
244+
k_desc = tl.make_tensor_descriptor(
245+
K,
246+
shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
247+
strides=[HEAD_DIM * H // G, 1],
248+
block_shape=[BLOCK_N, BLOCK_D],
249+
)
250+
v_desc = tl.make_tensor_descriptor(
251+
V,
252+
shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
253+
strides=[HEAD_DIM * H // G, 1],
254+
block_shape=[BLOCK_N, BLOCK_D],
255+
)
232256

233257
# allocate buffers for q0, q1
234258
q0_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), tl.float16, 1)
@@ -326,20 +350,12 @@ def gdpa_kernel_tma_ws_blackwell(
326350
qk0 = tlx.local_load(qk_view) # , tlx.storage_kind.tmem)
327351
# ConsumerWait for qk, ProducerAcquire for p
328352
# if activation_enum_int == 3:
329-
p0 = (
330-
qk0
331-
* 0.5
332-
* (
333-
1
334-
+ tanh_approx_fp32(
335-
0.7978845608 * qk0 * (1.0 + 0.044715 * qk0 * qk0)
336-
)
337-
)
338-
) # fast_gelu(qk0)
339-
# else:
340-
# p0 = qk0
353+
p0 = fast_gelu(qk0)
341354
p0 *= qk_scale
342-
p0 = p0.to(V.dtype.element_ty) # v_dtype)
355+
if USE_ON_DEVICE_TMA:
356+
p0 = p0.to(V.dtype.element_ty) # v_dtype)
357+
else:
358+
p0 = p0.to(tlx.dtype_of(v_desc))
343359
qk_view = tlx.local_view(qk0_buf, bufIdx)
344360
p0_view = tlx.local_reinterpret(qk_view, tl.float16)
345361
tlx.local_store(p0_view, p0) # , tlx.storage_kind.tmem)
@@ -371,18 +387,23 @@ def gdpa_kernel_tma_ws_blackwell(
371387
)
372388
# tl.device_print("default producer_o0", accum_cnt_outer)
373389
tlx.barrier_arrive(consumer_release_o0_view, 1)
374-
o0_desc = tl.make_tensor_descriptor(
375-
Out,
376-
shape=[end_q.to(tl.int32), HEAD_DIM * H],
377-
strides=[HEAD_DIM * H, 1],
378-
block_shape=[BLOCK_M // 2, BLOCK_D],
379-
)
380-
o0_desc.store(
390+
if USE_ON_DEVICE_TMA:
391+
o_desc = tl.make_tensor_descriptor(
392+
Out,
393+
shape=[end_q.to(tl.int32), HEAD_DIM * H],
394+
strides=[HEAD_DIM * H, 1],
395+
block_shape=[BLOCK_M // 2, BLOCK_D],
396+
)
397+
if USE_ON_DEVICE_TMA:
398+
o0 = o0.to(Out.type.element_ty)
399+
else:
400+
o0 = o0.to(tlx.dtype_of(o_desc))
401+
o_desc.store(
381402
[
382403
(begin_q + start_m * BLOCK_M).to(tl.int32),
383404
(out_offset).to(tl.int32),
384405
],
385-
o0.to(Out.type.element_ty),
406+
o0,
386407
)
387408
accum_cnt_outer += 1
388409
tile_idx += num_progs
@@ -420,20 +441,12 @@ def gdpa_kernel_tma_ws_blackwell(
420441
qk1 = tlx.local_load(qk_view) # , tlx.storage_kind.tmem)
421442
# ConsumerWait for qk, ProducerAcquire for p
422443
# if activation_enum_int == 3:
423-
p1 = (
424-
qk1
425-
* 0.5
426-
* (
427-
1
428-
+ tanh_approx_fp32(
429-
0.7978845608 * qk1 * (1.0 + 0.044715 * qk1 * qk1)
430-
)
431-
)
432-
) # fast_gelu(qk1)
433-
# else:
434-
# p1 = qk1
444+
p1 = fast_gelu(qk1)
435445
p1 *= qk_scale
436-
p1 = p1.to(V.dtype.element_ty) # v_dtype)
446+
if USE_ON_DEVICE_TMA:
447+
p1 = p1.to(V.dtype.element_ty) # v_dtype)
448+
else:
449+
p1 = p1.to(tlx.dtype_of(v_desc))
437450
qk_view = tlx.local_view(qk1_buf, bufIdx)
438451
p1_view = tlx.local_reinterpret(qk_view, tl.float16)
439452
tlx.local_store(p1_view, p1) # , tlx.storage_kind.tmem)
@@ -452,12 +465,13 @@ def gdpa_kernel_tma_ws_blackwell(
452465
bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
453466
accum_cnt_outer, NUM_BUFFERS_O
454467
)
455-
o0_desc = tl.make_tensor_descriptor(
456-
Out,
457-
shape=[end_q.to(tl.int32), HEAD_DIM * H],
458-
strides=[HEAD_DIM * H, 1],
459-
block_shape=[BLOCK_M // 2, BLOCK_D],
460-
)
468+
if USE_ON_DEVICE_TMA:
469+
o_desc = tl.make_tensor_descriptor(
470+
Out,
471+
shape=[end_q.to(tl.int32), HEAD_DIM * H],
472+
strides=[HEAD_DIM * H, 1],
473+
block_shape=[BLOCK_M // 2, BLOCK_D],
474+
)
461475
o1_view = tlx.local_view(
462476
o1_buf, bufIdx_o_outer
463477
) # FIXME: should be 0
@@ -467,12 +481,16 @@ def gdpa_kernel_tma_ws_blackwell(
467481
producer_o1, bufIdx_o_outer
468482
)
469483
tlx.barrier_arrive(consumer_release_o1_view, 1)
470-
o0_desc.store(
484+
if USE_ON_DEVICE_TMA:
485+
o1 = o1.to(Out.type.element_ty)
486+
else:
487+
o1 = o1.to(tlx.dtype_of(o_desc))
488+
o_desc.store(
471489
[
472490
(begin_q + start_m * BLOCK_M + BLOCK_M // 2).to(tl.int32),
473491
(out_offset).to(tl.int32),
474492
],
475-
o1.to(Out.type.element_ty),
493+
o1,
476494
)
477495
accum_cnt_outer += 1
478496
tile_idx += num_progs
@@ -581,6 +599,7 @@ def gdpa_kernel_tma_ws_blackwell(
581599
producer_o1_view = tlx.local_view(producer_o1, bufIdx_o_outer)
582600
# tl.device_print("gemm producer_o0", accum_cnt_outer)
583601
# tl.device_print("gemm producer_o0_phase", phase_o_outer)
602+
# DEBUG_PERF
584603
tlx.barrier_wait(
585604
producer_o0_view, phase_o_outer ^ 1
586605
) # producer acquire for o0
@@ -591,6 +610,7 @@ def gdpa_kernel_tma_ws_blackwell(
591610
consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_p)
592611
# tl.device_print("gemm producer_qk0", accum_cnt_qk)
593612
# tl.device_print("gemm producer_qk0_phase", phase_p)
613+
# DEBUG_PERF_P
594614
tlx.barrier_wait(
595615
consumer_p0_view, phase_p
596616
) # consumer wait for p0 due to reuse of p0 and qk0
@@ -660,11 +680,13 @@ def gdpa_kernel_tma_ws_blackwell(
660680
consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
661681
# tl.device_print("gemm producer_o1", accum_cnt_outer)
662682
# tl.device_print("gemm producer_o1_phase", phase_o_outer)
683+
# DEBUG_PERF
663684
tlx.barrier_wait(
664685
producer_o1_view, phase_o_outer ^ 1, first
665686
) # producer acquire for o1, only needed for first iteration
666687
# tl.device_print("gemm producer_qk1", accum_cnt_qk1)
667688
# tl.device_print("gemm producer_qk1_phase", phase_qk1)
689+
# DEBUG_PERF_P
668690
tlx.barrier_wait(
669691
consumer_p1_view, phase_qk1
670692
) # consumer wait for p1 use producer_qk1 due to reuse
@@ -741,6 +763,7 @@ def gdpa_kernel_tma_ws_blackwell(
741763
consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_qk)
742764
# tl.device_print("gemm producer_qk0", accum_cnt_qk)
743765
# tl.device_print("gemm producer_qk0_phase", phase_qk)
766+
# DEBUG_PERF_P
744767
tlx.barrier_wait(
745768
consumer_p0_view, phase_qk
746769
) # consumer wait for p0 use producer_qk0 due to reuse
@@ -780,6 +803,7 @@ def gdpa_kernel_tma_ws_blackwell(
780803
tlx.tcgen05_commit(release_q1_view)
781804
# tl.device_print("gemm producer_o1_epilogue", accum_cnt_outer)
782805
# tl.device_print("gemm producer_o1_phase", phase_o_outer)
806+
# DEBUG_PERF
783807
tlx.barrier_wait(
784808
producer_o1_view, phase_o_outer ^ 1, first
785809
) # producer acquire for o1 at the first iteration
@@ -789,6 +813,7 @@ def gdpa_kernel_tma_ws_blackwell(
789813
consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
790814
# tl.device_print("gemm producer_qk1_epilogue", accum_cnt_qk1)
791815
# tl.device_print("gemm producer_qk1_phase", phase_qk1)
816+
# DEBUG_PERF_P
792817
tlx.barrier_wait(
793818
consumer_p1_view, phase_qk1
794819
) # consumer wait for p1 due to reuse of p1 and qk1
@@ -862,12 +887,13 @@ def gdpa_kernel_tma_ws_blackwell(
862887
if start_m * BLOCK_M < qlen:
863888
# begin_o = tl.load(Out_offsets + off_z) # confirm if tma store should use begin_q
864889

865-
q_desc = tl.make_tensor_descriptor(
866-
Q,
867-
shape=[end_q.to(tl.int32), HEAD_DIM * H],
868-
strides=[HEAD_DIM * H, 1],
869-
block_shape=[BLOCK_M // 2, BLOCK_D],
870-
)
890+
if USE_ON_DEVICE_TMA:
891+
q_desc = tl.make_tensor_descriptor(
892+
Q,
893+
shape=[end_q.to(tl.int32), HEAD_DIM * H],
894+
strides=[HEAD_DIM * H, 1],
895+
block_shape=[BLOCK_M // 2, BLOCK_D],
896+
)
871897

872898
# calculate bufIdx and phase from accum_count_q
873899
q_bufIdx = accum_count_q % NUM_BUFFERS_Q
@@ -1131,6 +1157,40 @@ def gdpa_forward_tlx(
11311157
print("NUM_SMS", NUM_SMS)
11321158
print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads)
11331159

1160+
q = expect_contiguous(query)
1161+
k = expect_contiguous(key)
1162+
v = expect_contiguous(value)
1163+
kstrides = k.stride()
1164+
vstrides = v.stride()
1165+
1166+
dummy_block = [1, 1]
1167+
N_CTX_KV = max_seq_len_kv
1168+
HEAD_DIM = HEAD_DIM_K
1169+
Z = BATCH
1170+
H = nheads
1171+
y_dim = N_CTX_KV * Z
1172+
x_dim = HEAD_DIM * H // G
1173+
USE_ON_DEVICE_TMA = True
1174+
if not USE_ON_DEVICE_TMA:
1175+
desc_q = TensorDescriptor(
1176+
q,
1177+
shape=[y_dim, HEAD_DIM * H],
1178+
strides=[HEAD_DIM * H, 1],
1179+
block_shape=dummy_block,
1180+
)
1181+
desc_v = TensorDescriptor(
1182+
v, shape=[y_dim, x_dim], strides=[x_dim, 1], block_shape=dummy_block
1183+
)
1184+
desc_k = TensorDescriptor(
1185+
k, shape=[y_dim, x_dim], strides=[x_dim, 1], block_shape=dummy_block
1186+
)
1187+
desc_o = TensorDescriptor(
1188+
o,
1189+
shape=[y_dim, HEAD_DIM * H],
1190+
strides=[HEAD_DIM * H, 1],
1191+
block_shape=dummy_block,
1192+
)
1193+
11341194
# TMA descriptors require a global memory allocation
11351195
def alloc_fn(size: int, alignment: int, _):
11361196
return torch.empty(size, device="cuda", dtype=torch.int8)
@@ -1144,22 +1204,19 @@ def grid_tma_persistent(META):
11441204
1,
11451205
)
11461206

1147-
q = expect_contiguous(query)
1148-
k = expect_contiguous(key)
1149-
v = expect_contiguous(value)
1150-
kstrides = k.stride()
1151-
vstrides = v.stride()
1152-
11531207
activation_enum_int = activation_string_to_int(activation)
1208+
print(q.shape, k.shape, v.shape)
11541209
# print("activation_enum_int", activation, activation_enum_int)
1210+
# print(query_offset)
1211+
# print(key_offset)
11551212

11561213
gdpa_kernel_tma_ws_blackwell[grid_tma_persistent](
1157-
q,
1214+
q if USE_ON_DEVICE_TMA else desc_q,
11581215
query_offset,
1159-
k,
1216+
k if USE_ON_DEVICE_TMA else desc_k,
11601217
key_offset,
1161-
v,
1162-
o, #
1218+
v if USE_ON_DEVICE_TMA else desc_v,
1219+
o if USE_ON_DEVICE_TMA else desc_o,
11631220
output_offset,
11641221
ad_to_request_offset,
11651222
seq_index,
@@ -1194,6 +1251,7 @@ def grid_tma_persistent(META):
11941251
BROADCAST_Q=broadcast_q,
11951252
IS_DENSE_KV=is_dense_kv,
11961253
activation_enum_int=activation_enum_int,
1254+
USE_ON_DEVICE_TMA=USE_ON_DEVICE_TMA,
11971255
**extra_kern_args,
11981256
)
11991257
return o

0 commit comments

Comments
 (0)