Skip to content

Commit ff11de3

Browse files
authored
Sub-slicing P for TLX FA kernel
Differential Revision: D84267329 Pull Request resolved: #538
1 parent eb2bd7f commit ff11de3

File tree

1 file changed

+36
-7
lines changed

1 file changed

+36
-7
lines changed

tritonbench/kernels/tlx_attention_ws_pipelined.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,25 @@ def _compute_offsets_persistent(tile_idx, n_tile_num, H, N_CTX, BLOCK_M):
530530
return start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y
531531

532532

533+
@triton.jit
534+
def _split_n(x, SPLIT_FACTOR: tl.constexpr):
535+
if SPLIT_FACTOR == 1:
536+
return (x, )
537+
else:
538+
x0, x1 = x.reshape([x.shape[0], 2, x.shape[1] // 2]).permute(0, 2, 1).split()
539+
return _split_n(x0, SPLIT_FACTOR // 2) + _split_n(x1, SPLIT_FACTOR // 2)
540+
541+
@triton.jit
542+
def _join_n(xs):
543+
if len(xs) == 1:
544+
return xs[0]
545+
else:
546+
x0 = _join_n(xs[:len(xs) // 2])
547+
x1 = _join_n(xs[len(xs) // 2:])
548+
x = tl.join(x0, x1).permute(0, 2, 1).reshape([x0.shape[0], x0.shape[1] * 2])
549+
return x
550+
551+
533552
@triton.autotune(configs=configs_persistent, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"])
534553
@triton.jit
535554
def _attn_fwd_ws_persistent(
@@ -711,12 +730,12 @@ def _attn_fwd_ws_persistent(
711730
for cid in tl.static_range(0, NUM_MMA_GROUPS):
712731
tlx.barrier_wait(o_fulls[cid], phase)
713732
tlx.fence_async_shared()
714-
tlx.barrier_arrive(o_empties[cid])
715733
qo_offset_y_split = qo_offset_y + cid * BLOCK_M_SPLIT
716734
tlx.async_descriptor_store(
717735
desc_o, o_tiles[cid], [qo_offset_y_split, 0]
718736
)
719737
tlx.async_descriptor_store_wait(0)
738+
tlx.barrier_arrive(o_empties[cid])
720739

721740
tile_idx += num_progs
722741

@@ -751,17 +770,27 @@ def _attn_fwd_ws_persistent(
751770
tlx.local_store(alpha_tiles[cid * HEAD_DIM], alpha[:, None])
752771
tlx.barrier_arrive(alpha_fulls[cid])
753772

754-
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
755-
p = tl.math.exp2(qk)
756-
l_ij = tl.sum(p, 1)
757-
p = p.to(tlx.dtype_of(desc_v))
758773

759774
# prepare p for the v dot
760775
# Use p[1] for cid=0, and p[3] for cid=1
761776
p_bufIdx = 1 + cid * NUM_MMA_GROUPS
762-
tlx.local_store(p_tiles[p_bufIdx], p)
763-
tlx.barrier_arrive(p_fulls[cid])
764777

778+
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
779+
qks = _split_n(qk, NUM_MMA_SLICES)
780+
ps = ()
781+
for slice_id in tl.static_range(0, NUM_MMA_SLICES):
782+
p_i = tl.math.exp2(qks[slice_id])
783+
p_slice = tlx.subslice(
784+
p_tiles[p_bufIdx],
785+
HEAD_DIM * slice_id // NUM_MMA_SLICES,
786+
HEAD_DIM // NUM_MMA_SLICES,
787+
)
788+
tlx.local_store(p_slice, p_i.to(tlx.dtype_of(desc_v)))
789+
ps = ps + (p_i, )
790+
791+
tlx.barrier_arrive(p_fulls[cid])
792+
p = _join_n(ps)
793+
l_ij = tl.sum(p, 1)
765794
l_i = l_i * alpha + l_ij
766795
m_i = m_ij
767796
accum_cnt_qk += 1

0 commit comments

Comments
 (0)