
Commit 208642d

[Attention] Fix attn kernel for general GQA group size (#2543)
This PR fixes the TIR prefill attention kernels to support a broader list of GQA group sizes.
1 parent d5fbde2 commit 208642d
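Here the GQA group size is num_attention_heads // num_key_value_heads: the number of query heads sharing one KV head. The old kernels sized their per-CTA work as L_per_cta = tile_x // group_size, which rounds down and only tiles cleanly when the group size divides tile_x; the fix replaces it with a flattened (query position, head) index, LH_start. A plain-Python illustration of the problem (the model sizes are hypothetical, not taken from this diff):

# Illustrative only: a group size that does not divide the kernel's tile.
num_attention_heads = 28          # hypothetical config
num_key_value_heads = 4
group_size = num_attention_heads // num_key_value_heads   # 7

tile_x = 32                       # S-tile rows for fp16 and head dim <= 128 (see kernels below)
L_per_cta = tile_x // group_size  # old scheme: 4, but 4 * 7 = 28 != 32, so the tile's
                                  # 32 (row, head) slots straddle a fifth query position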

File tree: 3 files changed (+58, -44 lines)

python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def create_flashinfer_paged_kv_cache(
         in self.metadata["model_type"]
     )
     # filter by attention group size
-    or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 6, 8]
+    or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8]
 ):
     return
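This filter change means a group size of 6 no longer selects the FlashInfer path; such configurations now fall back to the TIR kernels patched below. A hedged sketch of the decision (names from the diff above; the fallback behavior is my reading of the surrounding pass, not spelled out in this hunk):

# Sketch only, not the full compiler pass.
kwargs = {"num_attention_heads": 24, "num_key_value_heads": 4}  # hypothetical config
group_size = kwargs["num_attention_heads"] // kwargs["num_key_value_heads"]  # 6
if group_size not in [1, 4, 8]:
    # the pass returns here without creating a FlashInfer KV cache,
    # leaving attention to the TIR kernels fixed in this commit
    ...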

python/mlc_llm/nn/kv_cache.py

Lines changed: 30 additions & 22 deletions
@@ -525,7 +525,6 @@ def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target
     bdx = 32
     num_warps = 4
     tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
-    L_per_cta = tile_x // group_size

     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -641,8 +640,7 @@ def batch_prefill_paged_kv(

         if T.tvm_thread_invariant(batch_idx[0] < batch_size):
             b_idx: T.int32 = batch_idx[0]
-            L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-            H_qo_start: T.int32 = by * group_size
+            LH_start: T.int32 = tile_id[0] * tile_x

             cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
             cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
@@ -672,8 +670,8 @@ def batch_prefill_paged_kv(
                 i, j = T.axis.remap("SS", [li, lj])
                 T.reads()
                 T.writes()
-                cur_L = L_start + i // group_size
-                cur_H_qo = H_qo_start + i % group_size
+                cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
+                cur_H_qo = by * group_size + (LH_start + i) % group_size
                 if cur_L < q_indptr[b_idx + 1]:
                     Q_smem[i, j] = T.if_then_else(
                         rotary_mode == 1,
@@ -742,9 +740,10 @@ def batch_prefill_paged_kv(
                 m_prev[i] = m_smem[row]
                 m_new[i] = m_smem[row]
                 # mask out of kv_chunk_len S
+                row_: T.int32 = (LH_start + row) // group_size
                 for j in T.serial(tile_z):
                     if _causal_mask(causal,
-                            row=tile_id[0] * L_per_cta + row // group_size,
+                            row=row_,
                             col=L_kv_start + j,
                             kv_len=kv_chunk_len[0],
                             qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -757,8 +756,9 @@ def batch_prefill_paged_kv(
                 for j in T.serial(tile_z):
                     # this is to avoid sync inside condition branch
                     if row < tile_x:
+                        row_: T.int32 = (LH_start + row) // group_size
                         if _causal_mask(causal,
-                                row=tile_id[0] * L_per_cta + row // group_size,
+                                row=row_,
                                 col=L_kv_start + j,
                                 kv_len=kv_chunk_len[0],
                                 qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -790,15 +790,19 @@ def batch_prefill_paged_kv(
             for li, lj in T.grid(tile_x, tile_y):
                 with T.block("O_store"):
                     i, j = T.axis.remap("SS", [li, lj])
-                    if L_start + i // group_size < q_indptr[b_idx + 1]:
-                        output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                    cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                    if cur_L < q_indptr[b_idx + 1]:
+                        output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

             # Store LSE to gmem
             for li in T.grid(tile_x):
                 with T.block("lse_store"):
                     i = T.axis.remap("S", [li])
-                    if L_start + i // group_size < q_indptr[b_idx + 1]:
-                        lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                    cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                    if cur_L < q_indptr[b_idx + 1]:
+                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

             # move to next tile
             tile_id[0] += NUM_BLKS
@@ -1218,7 +1222,6 @@ def _attention_prefill_ragged(
     bdx = 32
     num_warps = 4
     tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
-    L_per_cta = tile_x // group_size

     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -1313,8 +1316,7 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches

         if T.tvm_thread_invariant(batch_idx[0] < batch_size):
             b_idx: T.int32 = batch_idx[0]
-            L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-            H_qo_start: T.int32 = by * group_size
+            LH_start: T.int32 = tile_id[0] * tile_x

             kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
             T.tvm_storage_sync("shared")
@@ -1338,8 +1340,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                 i, j = T.axis.remap("SS", [li, lj])
                 T.reads()
                 T.writes()
-                cur_L = L_start + i // group_size
-                cur_H_qo = H_qo_start + i % group_size
+                cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
+                cur_H_qo = by * group_size + (LH_start + i) % group_size
                 if cur_L < q_indptr[b_idx + 1]:
                     Q_smem[i, j] = T.if_then_else(
                         rotary_mode == 1,
@@ -1403,9 +1405,10 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                 m_prev[i] = m_smem[row]
                 m_new[i] = m_smem[row]
                 # mask out of kv_chunk_len S
+                row_: T.int32 = (LH_start + row) // group_size
                 for j in T.serial(tile_z):
                     if _causal_mask(causal,
-                            row=tile_id[0] * L_per_cta + row // group_size,
+                            row=row_,
                             col=L_kv_start + j,
                             kv_len=kv_chunk_len[0],
                             qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1418,8 +1421,9 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                 for j in T.serial(tile_z):
                     # this is to avoid sync inside condition branch
                     if row < tile_x:
+                        row_: T.int32 = (LH_start + row) // group_size
                         if _causal_mask(causal,
-                                row=tile_id[0] * L_per_cta + row // group_size,
+                                row=row_,
                                 col=L_kv_start + j,
                                 kv_len=kv_chunk_len[0],
                                 qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1451,15 +1455,19 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
             for li, lj in T.grid(tile_x, tile_y):
                 with T.block("O_store"):
                     i, j = T.axis.remap("SS", [li, lj])
-                    if L_start + i // group_size < q_indptr[b_idx + 1]:
-                        output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                    cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                    if cur_L < q_indptr[b_idx + 1]:
+                        output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

             # Store LSE to gmem
             for li in T.grid(tile_x):
                 with T.block("lse_store"):
                     i = T.axis.remap("S", [li])
-                    if L_start + i // group_size < q_indptr[b_idx + 1]:
-                        lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                    cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                    if cur_L < q_indptr[b_idx + 1]:
+                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

             # move to next tile
             tile_id[0] += NUM_BLKS
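To see why the flattened LH_start indexing handles any group size, note that each tile now owns a contiguous range of flat indices, and (L, H) is decoded per index rather than per tile. A minimal plain-Python check (tile_x and group_size are illustrative values; 6 deliberately does not divide 32):

# Every (L, H) pair appears exactly once across tiles, with no overlap,
# regardless of whether group_size divides tile_x.
tile_x, group_size, num_tiles = 32, 6, 3
covered = []
for tile_id in range(num_tiles):
    LH_start = tile_id * tile_x                 # flat offset of this tile
    for i in range(tile_x):
        cur_L = (LH_start + i) // group_size    # query position (before adding q_indptr)
        cur_H = (LH_start + i) % group_size     # head within the GQA group
        covered.append((cur_L, cur_H))
assert len(covered) == len(set(covered)) == num_tiles * tile_x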

python/mlc_llm/op/tree_attn.py

Lines changed: 27 additions & 21 deletions
@@ -72,7 +72,6 @@ def tree_attn(h_kv, h_q, d, dtype, target: Target):  # pylint: disable=unused-argument
     bdx = 32
     num_warps = 4
     tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
-    L_per_cta = tile_x // group_size

     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -170,8 +169,7 @@ def batch_tree_attn(  # pylint: disable=too-many-branches

         if T.tvm_thread_invariant(batch_idx[0] < batch_size):
             b_idx: T.int32 = batch_idx[0]
-            L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-            H_qo_start: T.int32 = by * group_size
+            LH_start: T.int32 = tile_id[0] * tile_x

             kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
             T.tvm_storage_sync("shared")
@@ -195,8 +193,8 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                 i, j = T.axis.remap("SS", [li, lj])
                 T.reads()
                 T.writes()
-                cur_L = L_start + i // group_size
-                cur_H_qo = H_qo_start + i % group_size
+                cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
+                cur_H_qo = by * group_size + (LH_start + i) % group_size
                 if cur_L < q_indptr[b_idx + 1]:
                     Q_smem[i, j] = T.if_then_else(
                         rotary_mode == 1,
@@ -251,13 +249,15 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                 m_prev[i] = m_smem[row]
                 m_new[i] = m_smem[row]
                 # mask out of kv_chunk_len S
+                row_: T.int32 = (LH_start + row) // group_size
                 for j in T.serial(tile_z):
-                    if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
-                                  col=L_kv_start + j,
-                                  mask_ptr=mask,
-                                  offset=mn_indptr[b_idx],
-                                  stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
-                                  kv_len=kv_chunk_len[0]):
+                    if _tree_mask(
+                        row=row_,
+                        col=L_kv_start + j,
+                        mask_ptr=mask,
+                        offset=mn_indptr[b_idx],
+                        stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                        kv_len=kv_chunk_len[0]):
                         m_new[i] = T.max(m_new[i], S_smem[row, j])
                 d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])

@@ -267,12 +267,14 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                 for j in T.serial(tile_z):
                     # this is to avoid sync inside condition branch
                     if row < tile_x:
-                        if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
-                                      col=L_kv_start + j,
-                                      mask_ptr=mask,
-                                      offset=mn_indptr[b_idx],
-                                      stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
-                                      kv_len=kv_chunk_len[0]):
+                        row_: T.int32 = (LH_start + row) // group_size
+                        if _tree_mask(
+                            row=row_,
+                            col=L_kv_start + j,
+                            mask_ptr=mask,
+                            offset=mn_indptr[b_idx],
+                            stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                            kv_len=kv_chunk_len[0]):
                             S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                         else:
                             S_smem[row, j] = T.exp2(-5e4 - m_new[i])
@@ -301,15 +303,19 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
             for li, lj in T.grid(tile_x, tile_y):
                 with T.block("O_store"):
                     i, j = T.axis.remap("SS", [li, lj])
-                    if L_start + i // group_size < q_indptr[b_idx + 1]:
-                        output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                    cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                    if cur_L < q_indptr[b_idx + 1]:
+                        output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

             # Store LSE to gmem
             for li in T.grid(tile_x):
                 with T.block("lse_store"):
                     i = T.axis.remap("S", [li])
-                    if L_start + i // group_size < q_indptr[b_idx + 1]:
-                        lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                    cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                    if cur_L < q_indptr[b_idx + 1]:
+                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

             # move to next tile
             tile_id[0] += NUM_BLKS
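The mask-row change here (and the analogous causal-mask change in kv_cache.py) is load-bearing, not cosmetic: the old row formula diverges from the true query position whenever group_size does not divide tile_x, so masking could be applied to the wrong rows. A plain-Python demonstration with illustrative values:

# Compare the removed and new mask-row formulas for a misaligned group size.
tile_x, group_size, tile_id = 32, 6, 1
L_per_cta = tile_x // group_size            # 5, rounded down from 32/6
for row in range(tile_x):
    old_row = tile_id * L_per_cta + row // group_size   # removed formula
    new_row = (tile_id * tile_x + row) // group_size    # (LH_start + row) // group_size
    if old_row != new_row:
        print(row, old_row, new_row)        # first divergence: row=4 -> old 5, new 6
        break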
