Skip to content

Commit 2ecaf43

Browse files
njriasan authored and meta-codesync[bot] committed
[TLX] [FA] Enable BLOCK_SIZE=128 with H-DIM=64 on TLX's FA implementation (#700)
Summary: Adds the fixes to the kernel needed to use `BLOCK_SIZE=128` with `H-DIM=64`. This allows TLX to use a consistent block size with both Gluon and CUTLASS, both of which always use 256 and 128 for both H-DIM=128 and H-DIM=64. Pull Request resolved: #700 Reviewed By: adamomainz Differential Revision: D88070466 Pulled By: njriasan fbshipit-source-id: 6217249cbf8a87587a65a07df110e53d2678512f
1 parent 6536e28 commit 2ecaf43

File tree

1 file changed

+45
-19
lines changed

1 file changed

+45
-19
lines changed

third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,30 @@ def _host_descriptor_pre_hook(nargs):
4343
num_warps=4,
4444
pre_hook=_host_descriptor_pre_hook,
4545
),
46+
triton.Config(
47+
{
48+
"BLOCK_M": 256,
49+
"BLOCK_N": 128,
50+
"NUM_BUFFERS_Q": 1,
51+
"NUM_BUFFERS_KV": 6,
52+
"NUM_BUFFERS_QK": 1,
53+
"NUM_MMA_GROUPS": 2,
54+
"NUM_MMA_SLICES": 2,
55+
},
56+
num_stages=0,
57+
num_warps=4,
58+
pre_hook=_host_descriptor_pre_hook,
59+
),
4660
]
4761

4862

63+
def prune_configs_by_hdim(configs, named_args, **kwargs):
64+
HEAD_DIM = kwargs["HEAD_DIM"]
65+
target_kv_buffers = 6 if HEAD_DIM == 64 else 3
66+
# Only match HEAD_DIM for BLOCK_N
67+
return [conf for conf in configs if conf.kwargs.get("NUM_BUFFERS_KV", 0) == target_kv_buffers]
68+
69+
4970
@triton.jit
5071
def _get_bufidx_phase(accum_cnt, NUM_BUFFERS_KV):
5172
bufIdx = accum_cnt % NUM_BUFFERS_KV
@@ -161,15 +182,15 @@ def _mask_scalar(qk, col_limit_right, s, i):
161182

162183

163184
@triton.jit
164-
def _apply_causal_mask(qk, col_limit_right, HEAD_DIM: tl.constexpr):
185+
def _apply_causal_mask(qk, col_limit_right, BLOCK_N: tl.constexpr):
165186
# Apply causal mask via a bitmask calculated for each block of 16 elements.
166187
# This allows the efficient R2P (register to predicate) instruction to be used at the SASS level.
167188
# Credit to Tri Dao,
168189
# https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78
169190
#
170191
# NOTE: We use map_elementiwse here in order to generate an interleaved sequence of instructions
171192
# that processes one element of qk at a time. This improves ptxas's resulting SASS.
172-
offs_n = tl.arange(0, HEAD_DIM)[None, :]
193+
offs_n = tl.arange(0, BLOCK_N)[None, :]
173194
s = offs_n & ~0xF
174195
i = offs_n & 0xF
175196
return tl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i)
@@ -209,16 +230,16 @@ def _softmax_inner_loop(
209230

210231
if STAGE == 2:
211232
col_limit_right = (offs_m - start_n + 1)[:, None]
212-
qk = _apply_causal_mask(qk, col_limit_right, HEAD_DIM)
233+
qk = _apply_causal_mask(qk, col_limit_right, BLOCK_N)
213234

214235
# compute m_i, p in registers
215236
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
216237

217238
# -- compute correction factor
218239
alpha = tl.math.exp2(m_i - m_ij)
219240
tlx.barrier_wait(tlx.local_view(alpha_empties, cid), qk_phase ^ 1)
220-
# Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1
221-
tlx.local_store(tlx.local_view(alpha_tiles, cid * HEAD_DIM), alpha[:, None])
241+
# Use alpha[0] for cid=0, and alpha[BLOCK_N] for cid=1
242+
tlx.local_store(tlx.local_view(alpha_tiles, cid * BLOCK_N), alpha[:, None])
222243
tlx.barrier_arrive(tlx.local_view(alpha_fulls, cid))
223244

224245
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
@@ -243,7 +264,11 @@ def _softmax_inner_loop(
243264
return m_i, l_i, accum_cnt_qk
244265

245266

246-
@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "STAGE"])
267+
@triton.autotune(
268+
configs=configs,
269+
key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "STAGE"],
270+
prune_configs_by={"early_config_prune": prune_configs_by_hdim},
271+
)
247272
@triton.jit
248273
def _attn_fwd_ws(sm_scale, M, #
249274
Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
@@ -258,7 +283,6 @@ def _attn_fwd_ws(sm_scale, M, #
258283
NUM_MMA_GROUPS: tl.constexpr, #
259284
NUM_MMA_SLICES: tl.constexpr, #
260285
):
261-
tl.static_assert(BLOCK_N <= HEAD_DIM)
262286
tl.static_assert(NUM_MMA_GROUPS == 2)
263287
tl.static_assert(NUM_BUFFERS_QK == 1)
264288
tl.static_assert(NUM_BUFFERS_Q == 1)
@@ -357,8 +381,8 @@ def _attn_fwd_ws(sm_scale, M, #
357381
for cid in tl.static_range(0, NUM_MMA_GROUPS):
358382
# -- update output accumulator --
359383
tlx.barrier_wait(alpha_fulls[cid], phase)
360-
# Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1
361-
alpha_1 = tlx.local_load(alpha_tiles[cid * HEAD_DIM])
384+
# Use alpha[0] for cid=0, and alpha[BLOCK_N] for cid=1
385+
alpha_1 = tlx.local_load(alpha_tiles[cid * BLOCK_N])
362386
tlx.barrier_arrive(alpha_empties[cid])
363387
for slice_id in tl.static_range(0, NUM_MMA_SLICES):
364388
subslice = tlx.subslice(
@@ -377,11 +401,11 @@ def _attn_fwd_ws(sm_scale, M, #
377401
for cid in tl.static_range(0, NUM_MMA_GROUPS):
378402
# epilogue
379403
tlx.barrier_wait(l_fulls[cid], phase)
380-
# Use l[1]/l[1+HEAD_DIM] and m[2][2 + HEAD_DIM]
381-
# to disambiguate from alpha[0]/alpha[HEAD_DIM]
382-
l = tlx.local_load(l_tiles[cid * HEAD_DIM + 1])
404+
# Use l[1]/l[1+BLOCK_N] and m[2][2 + BLOCK_N]
405+
# to disambiguate from alpha[0]/alpha[BLOCK_N]
406+
l = tlx.local_load(l_tiles[cid * BLOCK_N + 1])
383407
tlx.barrier_arrive(qk_empties[cid])
384-
m = tlx.local_load(m_tiles[cid * HEAD_DIM + 2])
408+
m = tlx.local_load(m_tiles[cid * BLOCK_N + 2])
385409
m += tl.math.log2(l)
386410
offs_m = (start_m * BLOCK_M + cid * BLOCK_M_SPLIT + tl.arange(0, BLOCK_M_SPLIT))
387411
m_ptrs = M + off_hz * N_CTX + offs_m
@@ -479,10 +503,10 @@ def _attn_fwd_ws(sm_scale, M, #
479503
)
480504

481505
# prepare l_i for the epilog
482-
# Use l[1]/l[1+HEAD_DIM] and m[2][2 + HEAD_DIM]
483-
# to disambiguate from alpha[0]/alpha[HEAD_DIM]
484-
tlx.local_store(l_tiles[cid * HEAD_DIM + 1], l_i[:, None])
485-
tlx.local_store(m_tiles[cid * HEAD_DIM + 2], m_i[:, None])
506+
# Use l[1]/l[1+BLOCK_N] and m[2][2 + BLOCK_N]
507+
# to disambiguate from alpha[0]/alpha[BLOCK_N]
508+
tlx.local_store(l_tiles[cid * BLOCK_N + 1], l_i[:, None])
509+
tlx.local_store(m_tiles[cid * BLOCK_N + 2], m_i[:, None])
486510
tlx.barrier_arrive(l_fulls[cid])
487511
tile_idx += num_progs
488512

@@ -1621,7 +1645,7 @@ def grid(meta):
16211645
@pytest.mark.parametrize("Z", [8])
16221646
@pytest.mark.parametrize("H", [16])
16231647
@pytest.mark.parametrize("N_CTX", [1024])
1624-
@pytest.mark.parametrize("HEAD_DIM", [128])
1648+
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
16251649
@pytest.mark.parametrize("mode", ["fwd", "bwd"])
16261650
@pytest.mark.parametrize("provider", ["triton-fp16"])
16271651
@pytest.mark.parametrize("causal", [True, False])
@@ -1633,7 +1657,9 @@ def test_op(Z, H, N_CTX, HEAD_DIM, mode, provider, causal, dtype=torch.float16):
16331657
sm_scale = 0.5
16341658
# reference implementation
16351659
ref_dtype = dtype
1636-
if mode == "fwd" and not causal:
1660+
if mode == "bwd" and HEAD_DIM == 64:
1661+
pytest.skip("Only test bwd with 128")
1662+
elif mode == "fwd" and not causal and HEAD_DIM == 128:
16371663
pytest.skip("Only test fwd with causal")
16381664
elif mode == "bwd" and causal:
16391665
pytest.skip("Causal not supported for bwd yet")

0 commit comments

Comments
 (0)