Commit dee73ff

add non-persistent variant with data partitioning

Differential Revision: D83419153
Pull Request resolved: #487

1 parent 4f60f45 commit dee73ff

File tree

2 files changed: +157 -29 lines changed


tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 146 additions & 29 deletions
@@ -56,6 +56,7 @@ def _attn_fwd_subtile(
     v,
     dtype: tl.constexpr,
     STAGE: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     qk = tl.dot(q, k)
     if STAGE == 2:
@@ -75,10 +76,13 @@ def _attn_fwd_subtile(
     BM: tl.constexpr = acc.shape[0]
     BN: tl.constexpr = acc.shape[1]

-    acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
-    acc0 = acc0 * alpha[:, None]
-    acc1 = acc1 * alpha[:, None]
-    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+    if SUBTILING:
+        acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
+        acc0 = acc0 * alpha[:, None]
+        acc1 = acc1 * alpha[:, None]
+        acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+    else:
+        acc = acc * alpha[:, None]

     # prepare p and v for the dot
     p = p.to(dtype)
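
Both branches of the new SUBTILING switch apply the same per-row rescale of the accumulator by alpha; the subtiled branch only splits the columns into two halves so each half can be rescaled as its own tile. A standalone PyTorch sketch (not part of the commit; tile shape chosen arbitrarily) showing that the split/scale/join path is numerically identical to the plain rescale:

    import torch

    BM, BN = 128, 128                  # illustrative tile shape
    acc = torch.randn(BM, BN)
    alpha = torch.rand(BM)

    # Plain rescale, as in the SUBTILING=False branch.
    plain = acc * alpha[:, None]

    # Subtiled rescale, mirroring the SUBTILING=True branch:
    # split the columns into two halves, scale each, then join back.
    acc0, acc1 = acc.reshape(BM, 2, BN // 2).permute(0, 2, 1).unbind(dim=2)
    acc0 = acc0 * alpha[:, None]
    acc1 = acc1 * alpha[:, None]
    subtiled = torch.stack((acc0, acc1), dim=2).permute(0, 2, 1).reshape(BM, BN)

    assert torch.allclose(plain, subtiled)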
@@ -117,6 +121,7 @@ def _attn_fwd_inner_oss_dp(
     offs_n: tl.constexpr,  #
     N_CTX: tl.constexpr,
     warp_specialize: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     # range of values handled by this stage
     if STAGE == 1:
@@ -139,10 +144,34 @@ def _attn_fwd_inner_oss_dp(
         v = desc_v.load([offsetkv_y, 0])

         l_i0, m_i0, acc0 = _attn_fwd_subtile(
-            q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype, STAGE
+            q0,
+            k,
+            offs_m0,
+            start_n,
+            offs_n,
+            qk_scale,
+            l_i0,
+            m_i0,
+            acc0,
+            v,
+            dtype,
+            STAGE,
+            SUBTILING,
         )
         l_i1, m_i1, acc1 = _attn_fwd_subtile(
-            q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype, STAGE
+            q1,
+            k,
+            offs_m1,
+            start_n,
+            offs_n,
+            qk_scale,
+            l_i1,
+            m_i1,
+            acc1,
+            v,
+            dtype,
+            STAGE,
+            SUBTILING,
         )

         offsetkv_y += BLOCK_N
@@ -174,15 +203,17 @@ def _host_descriptor_pre_hook(nargs):

 configs = [
     triton.Config(
-        {"BLOCK_M": BM, "BLOCK_N": BN},
+        {"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile},
         num_stages=s,
         num_warps=w,
         pre_hook=_host_descriptor_pre_hook,
+        # ir_override=f"/home/mren/OpenSource/tritonbench/override/_attn_fwd_persist.ttgir"
     )
     for BM in [256]
     for BN in [128]
     for s in NUM_STAGES_OPTIONS
     for w in [4]
+    for subtile in [True]
 ]


@@ -222,6 +253,8 @@ def _attn_fwd_tma_dp(
     desc_k,
     desc_v,
     desc_o,
+    pid,
+    off_hz,
     N_CTX,  #
     HEAD_DIM: tl.constexpr,  #
     BLOCK_M: tl.constexpr,  #
@@ -230,10 +263,11 @@ def _attn_fwd_tma_dp(
     STAGE: tl.constexpr,  #
     warp_specialize: tl.constexpr,  #
     dtype: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     tl.static_assert(BLOCK_N <= HEAD_DIM)
-    start_m = tl.program_id(0)
-    off_hz = tl.program_id(1)
+    start_m = pid  # tl.program_id(0)
+    # off_hz = tl.program_id(1)
     off_z = off_hz // H
     off_h = off_hz % H

@@ -283,6 +317,7 @@ def _attn_fwd_tma_dp(
             offs_n,
             N_CTX,  #
             warp_specialize,
+            SUBTILING,
         )
     if STAGE & 2:
         acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
@@ -309,6 +344,7 @@ def _attn_fwd_tma_dp(
             offs_n,
             N_CTX,  #
             warp_specialize,
+            SUBTILING,
         )

     m_i0 += tl.math.log2(l_i0)
@@ -324,6 +360,56 @@ def _attn_fwd_tma_dp(
     desc_o.store([qo_offset_y + BLOCK_M // 2, 0], acc1.to(dtype))


+@triton.autotune(
+    configs=list(filter(keep, configs)),
+    key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
+    prune_configs_by={"early_config_prune": prune_invalid_configs},
+)
+@triton.jit
+def _attn_fwd(
+    sm_scale,
+    M,  #
+    Z,
+    H,
+    desc_q,
+    desc_k,
+    desc_v,
+    desc_o,
+    N_CTX,  #
+    HEAD_DIM: tl.constexpr,  #
+    BLOCK_M: tl.constexpr,  #
+    BLOCK_N: tl.constexpr,  #
+    FP8_OUTPUT: tl.constexpr,  #
+    STAGE: tl.constexpr,  #
+    warp_specialize: tl.constexpr,  #
+    dtype: tl.constexpr,
+    SUBTILING: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    _attn_fwd_tma_dp(
+        sm_scale,
+        M,
+        Z,
+        H,
+        desc_q,
+        desc_k,
+        desc_v,
+        desc_o,
+        pid,
+        off_hz,
+        N_CTX,
+        HEAD_DIM,
+        BLOCK_M,
+        BLOCK_N,
+        FP8_OUTPUT,
+        STAGE,
+        warp_specialize,
+        dtype,
+        SUBTILING,
+    )
+
+
 @triton.autotune(
     configs=list(filter(keep, configs)),
     key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
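
The new non-persistent entry point _attn_fwd reads its query tile and its flattened batch*head index straight from a 2D launch grid and forwards them to _attn_fwd_tma_dp as pid and off_hz, which the callee then decomposes (off_z = off_hz // H, off_h = off_hz % H). A plain-Python sketch of that mapping, with illustrative sizes only (the real values come from q.shape and the autotuned BLOCK_M):

    Z, H, N_CTX, BLOCK_M = 4, 16, 8192, 256

    num_m_tiles = (N_CTX + BLOCK_M - 1) // BLOCK_M   # tl.cdiv(N_CTX, BLOCK_M)
    for off_hz in range(Z * H):                      # grid axis 1, tl.program_id(1)
        off_z, off_h = off_hz // H, off_hz % H       # batch index, head index
        for pid in range(num_m_tiles):               # grid axis 0, tl.program_id(0)
            start_m = pid                            # query tile handled by this program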
@@ -348,6 +434,7 @@ def _attn_fwd_persist(
     warp_specialize: tl.constexpr,  #
     OUTER_LOOP: tl.constexpr,
     dtype: tl.constexpr,
+    SUBTILING: tl.constexpr,
 ):
     n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
     prog_id = tl.program_id(0)
@@ -372,6 +459,8 @@ def _attn_fwd_persist(
             desc_k,
             desc_v,
             desc_o,
+            pid,
+            off_hz,
             N_CTX,
             HEAD_DIM,
             BLOCK_M,
@@ -380,6 +469,7 @@ def _attn_fwd_persist(
             STAGE,
             warp_specialize and not OUTER_LOOP,
             dtype,
+            SUBTILING,
         )
         tile_idx += num_progs

@@ -406,6 +496,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
         M = torch.empty(
             (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
         )
+        warp_specialize = baseVariant == "ws" or baseVariant == "ws_persistent"
         # Use device_descriptor for Hopper + warpspec.
         if supports_host_descriptor() and not (is_hopper() and warp_specialize):
             # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
@@ -473,33 +564,59 @@ def grid_persist(META):
                 1,
             )

+        def grid_debug(META):
+            return (
+                1,
+                1,
+                1,
+            )
+
         ctx.grid = grid
-        warp_specialize = baseVariant == "ws"
+        persistent = baseVariant == "persistent" or baseVariant == "ws_persistent"
         if is_blackwell() and warp_specialize:
             if HEAD_DIM_K == 128 and (
                 q.dtype == torch.float16 or q.dtype == torch.bfloat16
             ):
-                extra_kern_args["maxnreg"] = 168
+                extra_kern_args["maxnreg"] = 128
             else:
                 extra_kern_args["maxnreg"] = 80
-        _attn_fwd_persist[grid_persist](
-            sm_scale,
-            M,  #
-            q.shape[0],
-            q.shape[1],  #
-            desc_q,
-            desc_k,
-            desc_v,
-            desc_o,  #
-            N_CTX=q.shape[2],  #
-            HEAD_DIM=HEAD_DIM_K,  #
-            FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
-            STAGE=stage,  #
-            warp_specialize=warp_specialize,
-            OUTER_LOOP=True,
-            dtype=torch_dtype_to_triton(q.dtype),
-            **extra_kern_args,
-        )
+        if persistent:
+            _attn_fwd_persist[grid_persist](
+                sm_scale,
+                M,  #
+                q.shape[0],
+                q.shape[1],  #
+                desc_q,
+                desc_k,
+                desc_v,
+                desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=warp_specialize,
+                OUTER_LOOP=True,
+                dtype=torch_dtype_to_triton(q.dtype),
+                **extra_kern_args,
+            )
+        else:
+            _attn_fwd[grid](
+                sm_scale,
+                M,  #
+                q.shape[0],
+                q.shape[1],  #
+                desc_q,
+                desc_k,
+                desc_v,
+                desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=warp_specialize,
+                dtype=torch_dtype_to_triton(q.dtype),
+                **extra_kern_args,
+            )


         ctx.save_for_backward(q, k, v, o, M)

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 11 additions & 0 deletions
@@ -454,6 +454,17 @@ def triton_tutorial_flash_dp_blackwell(
             q, k, v, self.causal, self.sm_scale, "ws"
         )

+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_dp_persistent_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: blackwell_triton_tutorial_FA2_dp(
+            q, k, v, self.causal, self.sm_scale, "ws_persistent"
+        )
+
     # Only works with triton main, forward only.
     @register_benchmark(enabled=False)
     def gluon_blackwell_tutorial_fwd(
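
With this change the operator exposes both schedules through the existing baseVariant string: "ws" now drives the non-persistent _attn_fwd launch, while the new "ws_persistent" benchmark keeps the persistent kernel. A minimal usage sketch of the wrapped autograd function; the import assumes the kernel module exports its wrapper as attention (as in the upstream tutorial this file is derived from), and the shapes, causal flag, and sm_scale are arbitrary:

    import torch

    # Assumed export name; operator.py calls it under the alias
    # blackwell_triton_tutorial_FA2_dp.
    from tritonbench.kernels.blackwell_triton_fused_attention_dp import attention

    Z, H, N_CTX, HEAD_DIM = 4, 16, 8192, 128
    q, k, v = (
        torch.randn(Z, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    )
    sm_scale = HEAD_DIM**-0.5

    out_non_persistent = attention(q, k, v, False, sm_scale, "ws")         # non-persistent path
    out_persistent = attention(q, k, v, False, sm_scale, "ws_persistent")  # persistent path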
