
Commit c316434

mmoskal authored and joerunde committed
[Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill (vllm-project#4128)
1 parent 556df30 commit c316434

File tree

2 files changed: +28, −18 lines

tests/kernels/test_prefix_prefill.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 
 NUM_HEADS = [64]
 NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128]
+HEAD_SIZES = [128, 96]
 DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)

vllm/attention/ops/prefix_prefill.py

Lines changed: 27 additions & 17 deletions
@@ -47,7 +47,8 @@ def _fwd_kernel(
         stride_v_cache_bl,
         num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
-        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,  # head size
+        BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
@@ -59,26 +60,30 @@ def _fwd_kernel(
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
         offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
         offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
-        q = tl.load(
-            Q + off_q,
-            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-            other=0.0)
+        dim_mask = tl.where(
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+        q = tl.load(Q + off_q,
+                    mask=dim_mask[None, :] &
+                    (offs_m[:, None] < cur_batch_query_len),
+                    other=0.0)
 
         # # initialize pointer to m and l
         m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
         l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
 
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
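The dim_mask built from BLOCK_DMODEL_PADDED switches off the lanes beyond the real head size, so the extra columns loaded for q, k and v are filled with zeros. Zero-padding the head dimension is safe because those lanes contribute nothing to the attention dot products. A quick PyTorch check of that invariant (shapes chosen arbitrarily for illustration):

import torch
import torch.nn.functional as F

head_size, padded = 96, 128
q = torch.randn(4, head_size)
k = torch.randn(4, head_size)
# zero-pad the head dimension up to the power-of-two size
q_pad = F.pad(q, (0, padded - head_size))
k_pad = F.pad(k, (0, padded - head_size))
# the padded lanes are zero, so every dot product is unchanged
assert torch.allclose(q @ k.T, q_pad @ k_pad.T, atol=1e-5)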
@@ -99,7 +104,8 @@ def _fwd_kernel(
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
-                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -126,7 +132,8 @@ def _fwd_kernel(
             acc = acc * acc_scale[:, None]
             # update acc
             v = tl.load(V_cache + off_v,
-                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -142,16 +149,15 @@ def _fwd_kernel(
         k_ptrs = K + off_k
         v_ptrs = V + off_v
 
-        block_mask = tl.where(
-            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+        block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             k = tl.load(k_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                        mask=(start_n + offs_n[None, :]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[:, None] &
+                        ((start_n + offs_n[None, :]) < cur_batch_query_len),
                         other=0.0)
 
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -179,8 +185,8 @@ def _fwd_kernel(
             # update acc
             v = tl.load(v_ptrs +
                         (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                        mask=(start_n + offs_n[:, None]) <
-                        cur_batch_seq_len - cur_batch_ctx_len,
+                        mask=dim_mask[None, :] &
+                        ((start_n + offs_n[:, None]) < cur_batch_query_len),
                         other=0.0)
 
             p = p.to(v.dtype)
@@ -195,7 +201,8 @@ def _fwd_kernel(
         out_ptrs = Out + off_o
         tl.store(out_ptrs,
                  acc,
-                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+                 mask=dim_mask[None, :] &
+                 (offs_m[:, None] < cur_batch_query_len))
         return
 
     @triton.jit
@@ -636,7 +643,8 @@ def context_attention_fwd(q,
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
-        assert Lk in {16, 32, 64, 128}
+        # round up Lk to a power of 2 - this is required for Triton block size
+        Lk_padded = 2**((Lk - 1).bit_length())
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
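The expression 2**((Lk - 1).bit_length()) rounds the head size up to the next power of two, which is the block extent the kernel uses for tl.arange and tl.zeros. A quick standalone check of the formula (the head sizes below are just examples):

for lk in (16, 32, 64, 80, 96, 128):
    lk_padded = 2**((lk - 1).bit_length())
    print(lk, "->", lk_padded)  # 16->16, 32->32, 64->64, 80->128, 96->128, 128->128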
@@ -646,6 +654,7 @@ def context_attention_fwd(q,
 
         num_warps = 8 if Lk <= 64 else 8
         if alibi_slopes is not None:
+            assert Lk == Lk_padded
            _fwd_kernel_alibi[grid](
                 q,
                 k,
@@ -738,6 +747,7 @@ def context_attention_fwd(q,
             num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
+            BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
             num_warps=num_warps,
             num_stages=1,
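Taken together, the launch now passes both the true head size (BLOCK_DMODEL=Lk) and its padded counterpart (BLOCK_DMODEL_PADDED=Lk_padded), and every load and store inside the kernel is guarded with dim_mask so the padded lanes are never read from or written past the real head dimension. The self-contained sketch below shows the same pad-and-mask pattern in isolation; the kernel name and shapes are made up for illustration and are not part of this commit:

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_head(X, Y, stride_x, stride_y,
               HEAD_SIZE: tl.constexpr,          # true head size, e.g. 96
               HEAD_SIZE_PADDED: tl.constexpr):  # next power of two, e.g. 128
    row = tl.program_id(0)
    offs_d = tl.arange(0, HEAD_SIZE_PADDED)  # tl.arange needs a power-of-two extent
    dim_mask = offs_d < HEAD_SIZE            # turn off the padded lanes
    x = tl.load(X + row * stride_x + offs_d, mask=dim_mask, other=0.0)
    tl.store(Y + row * stride_y + offs_d, x, mask=dim_mask)


def demo():
    head_size = 96
    head_size_padded = 2**((head_size - 1).bit_length())  # 128
    x = torch.randn(8, head_size, device="cuda", dtype=torch.float16)
    y = torch.empty_like(x)
    _copy_head[(x.shape[0], )](x, y, x.stride(0), y.stride(0),
                               HEAD_SIZE=head_size,
                               HEAD_SIZE_PADDED=head_size_padded)
    assert torch.equal(x, y)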
