Skip to content

Commit a21d269

Browse files
njriasan authored and meta-codesync[bot] committed
[TLX] Support pid swizzling with TLX FA (#707)
Summary: Adds PID swizzling to the TLX flash-attention (FA) kernel, which significantly improves the performance of causal attention. The swizzling heuristics are taken from Gluon.
Pull Request resolved: #707
Reviewed By: htyu
Differential Revision: D88106967
Pulled By: njriasan
fbshipit-source-id: 5e21629a1b8b68a76087fcb5bb7ed54519a03143
1 parent 99ddd86 commit a21d269

File tree

1 file changed

+106
-13
lines changed

1 file changed

+106
-13
lines changed

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,37 @@ def _host_descriptor_pre_hook(nargs):
3838
"NUM_BUFFERS_QK": 1,
3939
"NUM_MMA_GROUPS": 2,
4040
"NUM_MMA_SLICES": 2,
41+
"GROUP_SIZE_N": 1,
42+
},
43+
num_stages=0,
44+
num_warps=4,
45+
pre_hook=_host_descriptor_pre_hook,
46+
),
47+
triton.Config(
48+
{
49+
"BLOCK_M": 256,
50+
"BLOCK_N": 128,
51+
"NUM_BUFFERS_Q": 1,
52+
"NUM_BUFFERS_KV": 3,
53+
"NUM_BUFFERS_QK": 1,
54+
"NUM_MMA_GROUPS": 2,
55+
"NUM_MMA_SLICES": 2,
56+
"GROUP_SIZE_N": 4,
57+
},
58+
num_stages=0,
59+
num_warps=4,
60+
pre_hook=_host_descriptor_pre_hook,
61+
),
62+
triton.Config(
63+
{
64+
"BLOCK_M": 256,
65+
"BLOCK_N": 128,
66+
"NUM_BUFFERS_Q": 1,
67+
"NUM_BUFFERS_KV": 6,
68+
"NUM_BUFFERS_QK": 1,
69+
"NUM_MMA_GROUPS": 2,
70+
"NUM_MMA_SLICES": 2,
71+
"GROUP_SIZE_N": 1,
4172
},
4273
num_stages=0,
4374
num_warps=4,
@@ -52,6 +83,7 @@ def _host_descriptor_pre_hook(nargs):
5283
"NUM_BUFFERS_QK": 1,
5384
"NUM_MMA_GROUPS": 2,
5485
"NUM_MMA_SLICES": 2,
86+
"GROUP_SIZE_N": 4,
5587
},
5688
num_stages=0,
5789
num_warps=4,
@@ -62,9 +94,13 @@ def _host_descriptor_pre_hook(nargs):
6294

6395
def prune_configs_by_hdim(configs, named_args, **kwargs):
    """Keep only the autotune configs that match the current launch.

    A config survives when both of the following hold:

    * its ``NUM_BUFFERS_KV`` equals 6 if ``HEAD_DIM == 64``, otherwise 3;
    * its ``GROUP_SIZE_N`` equals 4 if ``STAGE == 3``, otherwise 1.

    Args:
        configs: Iterable of config objects exposing a ``kwargs`` dict.
        named_args: Unused; accepted to satisfy the pruning-callback
            signature.
        **kwargs: Launch keyword arguments. Must contain ``HEAD_DIM`` and
            ``STAGE``.

    Returns:
        A list of the configs from ``configs`` (in their original order)
        that satisfy both conditions. May be empty.

    Raises:
        KeyError: If ``HEAD_DIM`` or ``STAGE`` is missing from ``kwargs``.
    """
    HEAD_DIM = kwargs["HEAD_DIM"]
    STAGE = kwargs["STAGE"]
    target_kv_buffers = 6 if HEAD_DIM == 64 else 3
    # NOTE(review): STAGE == 3 is presumably the causal path (the commit
    # summary says swizzling helps causal attention) -- confirm with caller.
    target_group_size_n = 4 if STAGE == 3 else 1
    # A config that lacks either key falls back to 0, which never equals a
    # target value, so such configs are always dropped.
    return [
        conf
        for conf in configs
        if conf.kwargs.get("NUM_BUFFERS_KV", 0) == target_kv_buffers
        and conf.kwargs.get("GROUP_SIZE_N", 0) == target_group_size_n
    ]
68104

69105

70106
@triton.jit
@@ -140,9 +176,21 @@ def _get_fused_loop_bounds(start_m, N_CTX, BLOCK_M, STAGE: tl.constexpr):
140176

141177

142178
@triton.jit
143-
def _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE: tl.constexpr):
144-
start_m = tile_idx % n_tile_num
145-
off_hz = tile_idx // n_tile_num
179+
def _compute_offsets(
180+
tile_idx,
181+
H,
182+
num_pid_n,
183+
num_pid_in_group,
184+
N_CTX,
185+
BLOCK_M: tl.constexpr,
186+
STAGE: tl.constexpr,
187+
GROUP_SIZE_N: tl.constexpr,
188+
):
189+
group_id = tile_idx // num_pid_in_group
190+
first_pid_n = group_id * GROUP_SIZE_N
191+
group_size_n = min(num_pid_n - first_pid_n, GROUP_SIZE_N)
192+
start_m = (tile_idx % num_pid_in_group) // group_size_n
193+
off_hz = first_pid_n + (tile_idx % group_size_n)
146194
off_z = off_hz // H
147195
off_h = off_hz % H
148196
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
@@ -282,6 +330,7 @@ def _attn_fwd_ws(sm_scale, M, #
282330
NUM_BUFFERS_QK: tl.constexpr, #
283331
NUM_MMA_GROUPS: tl.constexpr, #
284332
NUM_MMA_SLICES: tl.constexpr, #
333+
GROUP_SIZE_N: tl.constexpr, #
285334
):
286335
tl.static_assert(NUM_MMA_GROUPS == 2)
287336
tl.static_assert(NUM_BUFFERS_QK == 1)
@@ -292,10 +341,12 @@ def _attn_fwd_ws(sm_scale, M, #
292341
# original grid
293342
# triton.cdiv(q.shape[2], META["BLOCK_M"]),
294343
# q.shape[0] * q.shape[1],
295-
n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
296344
prog_id = tl.program_id(0)
297345
num_progs = tl.num_programs(0)
298-
total_tiles = n_tile_num * Z * H
346+
num_pid_m = tl.cdiv(N_CTX, BLOCK_M)
347+
num_pid_n = Z * H
348+
num_pid_in_group = num_pid_m * GROUP_SIZE_N
349+
total_tiles = num_pid_m * Z * H
299350

300351
tiles_per_sm = total_tiles // num_progs
301352
if prog_id < total_tiles % num_progs:
@@ -375,7 +426,15 @@ def _attn_fwd_ws(sm_scale, M, #
375426
for i in range(0, tiles_per_sm):
376427
# initialize offsets
377428
start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(
378-
tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE)
429+
tile_idx,
430+
H,
431+
num_pid_n,
432+
num_pid_in_group,
433+
N_CTX,
434+
BLOCK_M,
435+
STAGE,
436+
GROUP_SIZE_N,
437+
)
379438
for _ in tl.range(lo, hi, BLOCK_N):
380439
_, phase = _get_bufidx_phase(accum_cnt, 1)
381440
for cid in tl.static_range(0, NUM_MMA_GROUPS):
@@ -439,7 +498,15 @@ def _attn_fwd_ws(sm_scale, M, #
439498
for i in range(0, tiles_per_sm):
440499
# initialize offsets
441500
start_m, off_hz, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(
442-
tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE)
501+
tile_idx,
502+
H,
503+
num_pid_n,
504+
num_pid_in_group,
505+
N_CTX,
506+
BLOCK_M,
507+
STAGE,
508+
GROUP_SIZE_N,
509+
)
443510
# initialize pointer to m and l
444511
m_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) - float("inf")
445512
l_i = tl.zeros([BLOCK_M_SPLIT], dtype=tl.float32) + 1.0
@@ -517,7 +584,16 @@ def _attn_fwd_ws(sm_scale, M, #
517584

518585
for j in range(0, tiles_per_sm):
519586
# initialize offsets
520-
_, _, lo, hi, _, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE)
587+
_, _, lo, hi, _, _ = _compute_offsets(
588+
tile_idx,
589+
H,
590+
num_pid_n,
591+
num_pid_in_group,
592+
N_CTX,
593+
BLOCK_M,
594+
STAGE,
595+
GROUP_SIZE_N,
596+
)
521597

522598
q_bufIdx, q_phase = _get_bufidx_phase(j, NUM_BUFFERS_Q)
523599
k_bufIdx, k_phase = _get_bufidx_phase(accum_cnt_kv, NUM_BUFFERS_KV)
@@ -685,8 +761,16 @@ def _attn_fwd_ws(sm_scale, M, #
685761
accum_cnt_kv = 0
686762
for i in range(0, tiles_per_sm):
687763
# initialize offsets
688-
_, _, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M,
689-
STAGE)
764+
_, _, lo, hi, qo_offset_y, kv_offset_y = _compute_offsets(
765+
tile_idx,
766+
H,
767+
num_pid_n,
768+
num_pid_in_group,
769+
N_CTX,
770+
BLOCK_M,
771+
STAGE,
772+
GROUP_SIZE_N,
773+
)
690774

691775
# load q0
692776
q_bufIdx, q_phase = _get_bufidx_phase(i, NUM_BUFFERS_Q)
@@ -758,7 +842,16 @@ def _attn_fwd_ws(sm_scale, M, #
758842
# initialize offsets
759843
for i in range(0, tiles_per_sm):
760844
# initialize offsets
761-
_, _, _, _, qo_offset_y, _ = _compute_offsets(tile_idx, n_tile_num, H, N_CTX, BLOCK_M, STAGE)
845+
_, _, _, _, qo_offset_y, _ = _compute_offsets(
846+
tile_idx,
847+
H,
848+
num_pid_n,
849+
num_pid_in_group,
850+
N_CTX,
851+
BLOCK_M,
852+
STAGE,
853+
GROUP_SIZE_N,
854+
)
762855
_, phase = _get_bufidx_phase(i, 1)
763856
for cid in tl.static_range(0, NUM_MMA_GROUPS):
764857
tlx.barrier_wait(o_fulls[cid], phase)

0 commit comments

Comments
 (0)