Skip to content

Commit 7183209

Browse files
authored
Update the Blackwell TLX persistent FA kernel
Differential Revision: D84069475
Pull Request resolved: #523
1 parent ad40bed commit 7183209

File tree

1 file changed

+75
-15
lines changed

1 file changed

+75
-15
lines changed

tritonbench/kernels/tlx_attention_ws_pipelined.py

Lines changed: 75 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,23 @@ def _host_descriptor_pre_hook(nargs):
4242
),
4343
]
4444

45+
# Autotuning search space for the persistent kernel: it is the `configs=`
# argument of the `triton.autotune` decorator on `_attn_fwd_ws_persistent`,
# kept as a separate list from `configs` (used by `_attn_fwd_ws`).
configs_persistent = [
    triton.Config(
        {
            "BLOCK_M": 256,
            "BLOCK_N": 128,
            "NUM_BUFFERS_Q": 1,
            "NUM_BUFFERS_KV": 3,
            "NUM_BUFFERS_QK": 1,
            # `_attn_fwd_ws_persistent` static-asserts NUM_MMA_GROUPS == 2.
            "NUM_MMA_GROUPS": 2,
            # Number of sub-slices each accumulator tile is split into along
            # HEAD_DIM when it is rescaled by alpha in the persistent kernel.
            "NUM_MMA_SLICES": 2,
        },
        num_stages=0,
        num_warps=4,
        # Hook run with the launch arguments before the kernel starts.
        pre_hook=_host_descriptor_pre_hook,
    ),
]
61+
4562

4663
@triton.jit
4764
def _get_bufidx_phase(accum_cnt, NUM_BUFFERS_KV):
@@ -63,6 +80,47 @@ def _compute_offsets(H, N_CTX, BLOCK_M):
6380
return start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y
6481

6582

83+
@triton.jit
def _fma_f32x2(a, b, c):
    """Return ``a * b + c`` elementwise using the packed PTX ``fma.rn.f32x2``.

    With ``pack=2`` each inline-asm invocation receives two float32 elements
    from every operand.  The two elements of each operand are packed into one
    64-bit register, a single ``fma.rn.f32x2`` (round-to-nearest) is issued on
    the packed registers, and the 64-bit result is unpacked into the two
    float32 outputs.

    Args:
        a: First multiplicand.
        b: Second multiplicand.
        c: Addend.

    Returns:
        A float32 result (``dtype=tl.float32``) holding ``a * b + c``.

    NOTE(review): ``fma.rn.f32x2`` is presumably only available on recent
    NVIDIA targets (the commit targets Blackwell) -- confirm against the PTX
    ISA before reusing this helper elsewhere.
    """
    return tl.inline_asm_elementwise(
        """
        {
            .reg .b64 ra, rb, rc, rd;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mov.b64 rc, { $6, $7 };
            fma.rn.f32x2 rd, ra, rb, rc;
            mov.b64 { $0, $1 }, rd;
        }
        """,
        # Two 32-bit outputs ($0, $1), then two 32-bit inputs per operand:
        # a -> ($2, $3), b -> ($4, $5), c -> ($6, $7).
        "=r,=r,r,r,r,r,r,r",
        [a, b, c],
        dtype=tl.float32,
        is_pure=True,
        pack=2,
    )
102+
103+
104+
@triton.jit
def _mul_f32x2(a, b):
    """Return ``a * b`` elementwise using the packed PTX ``mul.f32x2``.

    With ``pack=2`` each inline-asm invocation receives two float32 elements
    from every operand.  The two elements of each operand are packed into one
    64-bit register, a single ``mul.f32x2`` is issued on the packed
    registers, and the 64-bit result is unpacked into the two float32 outputs.

    Args:
        a: First factor.
        b: Second factor.

    Returns:
        A float32 result (``dtype=tl.float32``) holding ``a * b``.

    NOTE(review): unlike ``_fma_f32x2`` (which spells out ``.rn``), the
    instruction here carries no explicit rounding modifier -- confirm the
    default rounding behaviour in the PTX ISA if bit-exactness matters.
    """
    return tl.inline_asm_elementwise(
        """
        {
            .reg .b64 ra, rb, rc;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mul.f32x2 rc, ra, rb;
            mov.b64 { $0, $1 }, rc;
        }
        """,
        # Two 32-bit outputs ($0, $1), then two 32-bit inputs per operand:
        # a -> ($2, $3), b -> ($4, $5).
        "=r,=r,r,r,r,r",
        [a, b],
        dtype=tl.float32,
        is_pure=True,
        pack=2,
    )
122+
123+
66124
@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
67125
@triton.jit
68126
def _attn_fwd_ws(
@@ -472,7 +530,7 @@ def _compute_offsets_persistent(tile_idx, n_tile_num, H, N_CTX, BLOCK_M):
472530
return start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y
473531

474532

475-
@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
533+
@triton.autotune(configs=configs_persistent, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
476534
@triton.jit
477535
def _attn_fwd_ws_persistent(
478536
sm_scale,
@@ -492,6 +550,7 @@ def _attn_fwd_ws_persistent(
492550
NUM_BUFFERS_KV: tl.constexpr, #
493551
NUM_BUFFERS_QK: tl.constexpr, #
494552
NUM_MMA_GROUPS: tl.constexpr, #
553+
NUM_MMA_SLICES: tl.constexpr, #
495554
):
496555
tl.static_assert(BLOCK_N <= HEAD_DIM)
497556
tl.static_assert(NUM_MMA_GROUPS == 2)
@@ -593,24 +652,27 @@ def _attn_fwd_ws_persistent(
593652
)
594653
for _ in tl.range(lo, hi, BLOCK_N):
595654
_, phase = _get_bufidx_phase(accum_cnt, 1)
596-
for cid in tl.range(
597-
0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS
598-
):
655+
for cid in tl.static_range(0, NUM_MMA_GROUPS):
599656
# -- update output accumulator --
600657
tlx.barrier_wait(alpha_fulls[cid], phase)
601658
# Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1
602659
alpha_1 = tlx.local_load(alpha_tiles[cid * HEAD_DIM])
603660
tlx.barrier_arrive(alpha_empties[cid])
604-
acc = tlx.local_load(acc_tiles[cid])
605-
acc = acc * alpha_1
606-
tlx.local_store(acc_tiles[cid], acc)
661+
for slice_id in tl.static_range(0, NUM_MMA_SLICES):
662+
subslice = tlx.subslice(
663+
acc_tiles[cid],
664+
HEAD_DIM * slice_id // NUM_MMA_SLICES,
665+
HEAD_DIM // NUM_MMA_SLICES,
666+
)
667+
acc = tlx.local_load(subslice)
668+
# Equivalent to `acc = acc * alpha_1`, issued via the packed f32x2 multiply helper.
669+
acc = _mul_f32x2(acc, alpha_1)
670+
tlx.local_store(subslice, acc)
607671
tlx.barrier_arrive(acc_fulls[cid])
608672
accum_cnt += 1
609673

610674
_, phase = _get_bufidx_phase(i, 1)
611-
for cid in tl.range(
612-
0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS
613-
):
675+
for cid in tl.static_range(0, NUM_MMA_GROUPS):
614676
# epilogue
615677
tlx.barrier_wait(l_fulls[cid], phase)
616678
# Use l[1]/l[1 + HEAD_DIM] and m[2]/m[2 + HEAD_DIM]
@@ -646,9 +708,7 @@ def _attn_fwd_ws_persistent(
646708
_compute_offsets_persistent(tile_idx, n_tile_num, H, N_CTX, BLOCK_M)
647709
)
648710
_, phase = _get_bufidx_phase(i, 1)
649-
for cid in tl.range(
650-
0, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS
651-
):
711+
for cid in tl.static_range(0, NUM_MMA_GROUPS):
652712
tlx.barrier_wait(o_fulls[cid], phase)
653713
tlx.fence_async_shared()
654714
tlx.barrier_arrive(o_empties[cid])
@@ -661,7 +721,7 @@ def _attn_fwd_ws_persistent(
661721
tile_idx += num_progs
662722

663723
# softmax groups
664-
with tlx.async_task(num_warps=4, registers=152, replicate=NUM_MMA_GROUPS):
724+
with tlx.async_task(num_warps=4, registers=168, replicate=NUM_MMA_GROUPS):
665725
accum_cnt_qk = 0
666726
for i in range(0, tiles_per_sm):
667727
# initialize offsets
@@ -691,7 +751,7 @@ def _attn_fwd_ws_persistent(
691751
tlx.local_store(alpha_tiles[cid * HEAD_DIM], alpha[:, None])
692752
tlx.barrier_arrive(alpha_fulls[cid])
693753

694-
qk = qk * qk_scale - m_ij[:, None]
754+
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
695755
p = tl.math.exp2(qk)
696756
l_ij = tl.sum(p, 1)
697757
p = p.to(tlx.dtype_of(desc_v))

0 commit comments

Comments
 (0)