Commit d45865c

custom flash_attn

committed · 1 parent d70d188 · commit d45865c

File tree: 4 files changed (+47, -158 lines)


lightllm/common/basemodel/triton_kernel/gen_decode_params.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ def gen_decode_params(b_seq_len: torch.Tensor):

     if enable_fa3_mtp:
         b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(
-            b_q_seq_len[: len(b_seq_len) // mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size]
+            b_q_seq_len[mtp_size - 1 :: mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size]
         )
     else:
         b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
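Note on the hunk above: the new slice selects every mtp_size-th entry of b_q_seq_len (starting at index mtp_size - 1), the same positions already selected from b_kv_seq_len, instead of the first len(b_seq_len) // mtp_size entries. A minimal sketch of the indexing difference; mtp_size and the tensor values below are made up for illustration and are not taken from this commit:

import torch

mtp_size = 2
# Two requests, each contributing mtp_size entries (illustrative lengths only).
b_q_seq_len = torch.tensor([1, 2, 1, 2], dtype=torch.int32)
b_kv_seq_len = torch.tensor([7, 8, 15, 16], dtype=torch.int32)

old_q = b_q_seq_len[: len(b_q_seq_len) // mtp_size]  # indices 0, 1 -> tensor([1, 2]); the original indexes via len(b_seq_len)
new_q = b_q_seq_len[mtp_size - 1 :: mtp_size]        # indices 1, 3 -> tensor([2, 2])
kv = b_kv_seq_len[mtp_size - 1 :: mtp_size]          # indices 1, 3 -> tensor([8, 16]), unchanged by this commit

# Both slices keep one entry per request, but the new one picks the same
# positions as the kv slice, so the q/kv lengths passed to
# gen_cumsum_pad0_tensor come from the same tokens.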

lightllm/common/flash_attn.py

Lines changed: 36 additions & 145 deletions

@@ -1,29 +1,11 @@
-# This file is adapted from sgl-project/sglang:
-# https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/flash_attn.py
-# The original code and this file are licensed under the Apache License, Version 2.0.
-#
-# Copyright (c) sgl-project and other contributors.
-# Modifications Copyright (c) LightLLM contributors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 import torch
 from typing import List, Optional, Tuple, Union
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)


-def maybe_contiguous(x):
+def get_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x


@@ -34,152 +16,61 @@ def maybe_contiguous(x):

     def flash_attn_with_kvcache_mtp(
         q,
-        k_cache,
-        v_cache,
-        k=None,
-        v=None,
-        qv=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
-        cache_batch_idx: Optional[torch.Tensor] = None,
-        cache_leftpad: Optional[torch.Tensor] = None,
-        page_table: Optional[torch.Tensor] = None,
+        k,
+        v,
+        k_new: Optional[torch.Tensor] = None,
+        v_new: Optional[torch.Tensor] = None,
+        q_v: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
         cu_seqlens_k_new: Optional[torch.Tensor] = None,
+        seqused_q: Optional[torch.Tensor] = None,
+        seqused_k: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
+        max_seqlen_k: Optional[int] = None,
+        page_table: Optional[torch.Tensor] = None,
+        cache_batch_idx: Optional[torch.Tensor] = None,
+        cache_leftpad: Optional[torch.Tensor] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
         rotary_seqlens: Optional[torch.Tensor] = None,
         q_descale: Optional[torch.Tensor] = None,
         k_descale: Optional[torch.Tensor] = None,
         v_descale: Optional[torch.Tensor] = None,
         softmax_scale=None,
-        causal=False,
-        window_size=(-1, -1),  # -1 means infinite context window
+        is_causal=False,
+        window_size=(-1, -1),
         softcap=0.0,  # 0.0 means deactivated
-        rotary_interleaved=True,
+        is_rotary_interleaved=True,
         scheduler_metadata=None,
-        num_splits=0,  # Can be tuned for speed
-        pack_gqa=None,  # Can be tuned for speed
-        sm_margin=0,  # Can be tuned if some SMs are used for communication
-        return_softmax_lse=False,
+        num_splits=0,
+        pack_gqa=None,
+        sm_margin=0,
         mtp_step=0,
     ):
-        """
-        If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
-        k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
-        the previous step, and update them with the new keys/values from the current step, and do
-        attention with the updated cache, all in 1 kernel.
-
-        If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
-        For example, the KV cache could be pre-allocated with the max sequence length, and you can use
-        cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
-
-        Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
-        rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-        If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
-        and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-        If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
-        indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
-
-        See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
-
-        Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-        than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
-        For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-        0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
-
-        If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
-        For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
-            1 1 1 1 0
-            1 1 1 1 1
-        If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
-            0 0
-            0 0
-            0 0
-            1 0
-            1 1
-        If the row of the mask is all zero, the output will be zero.
-
-        If window_size != (-1, -1), implements sliding window local attention. Query at position i
-        will only attend to keys between
-        [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
-
-        Note: Does not support backward pass.
-
-        Arguments:
-            q: (batch_size, seqlen, nheads, headdim)
-            k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
-                or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
-                page_block_size must be a multiple of 256.
-            v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
-                or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
-            k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
-                k with k_cache, starting at the indices specified by cache_seqlens.
-            v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
-            qv [optional]: (batch_size, seqlen, nheads, headdim_v)
-            rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
-                to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
-            rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
-            cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
-                KV cache.
-            cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
-                If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
-                If the indices are not distinct, and k and v are provided, the values updated in the cache
-                might come from any of the duplicate indices.
-            cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
-            page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
-            softmax_scale: float. The scaling of QK^T before applying softmax.
-                Default to 1 / sqrt(headdim).
-            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-            window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-            softcap: float. Anything > 0 activates softcapping attention.
-            rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
-                If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
-                rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
-                (i.e. GPT-NeoX style).
-            num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-                If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-                to automatically determine the number of splits.
-                Don't change this unless you know what you are doing.
-            return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
-
-        Return:
-            out: (batch_size, seqlen, nheads, headdim).
-            softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
-                logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-                normalization factor).
-        """
-        assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
-        assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+        assert k.stride(-1) == 1, "k must have contiguous last dimension"
+        assert v.stride(-1) == 1, "v must have contiguous last dimension"
         if softmax_scale is None:
-            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-        if cache_seqlens is not None and isinstance(cache_seqlens, int):
-            cache_seqlens = torch.full((k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device)
-            cache_seqlens = maybe_contiguous(cache_seqlens)
-
-        q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
-        v_cache = v_cache.contiguous() if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 else v_cache
-        cu_seqlens_q, cu_seqlens_k_new = [maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
-        page_table, cache_batch_idx, cache_leftpad = [
-            maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
-        ]
-        rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
-        rotary_seqlens = maybe_contiguous(rotary_seqlens)
+            softmax_scale = (q.shape[-1] + (q_v.shape[-1] if q_v is not None else 0)) ** (-0.5)
+        seqused_k = get_contiguous(seqused_k)

-        # out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
+        q, k, k_new, v_new = [get_contiguous(x) for x in (q, k, k_new, v_new)]
+        v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
+        cu_seqlens_q, cu_seqlens_k_new = [get_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
+        page_table = get_contiguous(page_table)
         out, softmax_lse, *rest = flash_attn_3_mtp.fwd(
             q,
-            k_cache,
-            v_cache,
             k,
             v,
-            qv,
+            k_new,
+            v_new,
+            q_v,
             None,  # out
             cu_seqlens_q,
             None,  # cu_seqlens_k
             cu_seqlens_k_new,
             None,  # seqused_q
-            cache_seqlens,
+            seqused_k,
             max_seqlen_q,
             None,  # max_seqlen_k
             page_table,
@@ -192,19 +83,19 @@ def flash_attn_with_kvcache_mtp(
             k_descale,
             v_descale,
             softmax_scale,
-            causal,
+            is_causal,
             window_size[0],
             window_size[1],
             0,
             softcap,
-            rotary_interleaved,
+            is_rotary_interleaved,
             scheduler_metadata,
             num_splits,
             pack_gqa,
             sm_margin,
             mtp_step,
         )
-        return (out, softmax_lse, *rest) if return_softmax_lse else out
+        return out

 except:
     flash_attn_3_mtp = None
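For call sites migrating to the rewritten wrapper, the renames visible in this diff are k_cache -> k, v_cache -> v, qv -> q_v, cache_seqlens -> seqused_k, causal -> is_causal, and rotary_interleaved -> is_rotary_interleaved; return_softmax_lse is removed and the function now always returns just out, and the internal helper maybe_contiguous became get_contiguous. The sketch below is an illustrative porting helper, not part of this commit:

RENAMED_KWARGS = {
    "k_cache": "k",
    "v_cache": "v",
    "qv": "q_v",
    "cache_seqlens": "seqused_k",
    "causal": "is_causal",
    "rotary_interleaved": "is_rotary_interleaved",
}


def port_mtp_kwargs(old_kwargs: dict) -> dict:
    """Illustrative helper: rename old keyword arguments for the new
    flash_attn_with_kvcache_mtp signature and drop return_softmax_lse,
    which the new wrapper no longer accepts."""
    new_kwargs = {}
    for name, value in old_kwargs.items():
        if name == "return_softmax_lse":
            continue  # the new wrapper always returns only `out`
        new_kwargs[RENAMED_KWARGS.get(name, name)] = value
    return new_kwargs

Also note that the int form of cache_seqlens is gone: the old wrapper expanded an int into a per-batch int32 tensor, while the new seqused_k is expected to already be a tensor, as the deepseek2 call site below passes.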

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 5 additions & 6 deletions

@@ -563,21 +563,20 @@ def _token_gqa_decode_attention_mtp(
         k_descale, v_descale = None, None
         o_tensor = flash_attn_with_kvcache_mtp(
             q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim),
-            k_cache=k_rope,
-            v_cache=kv_nope,
-            qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
+            k=k_rope,
+            v=kv_nope,
+            q_v=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
             page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size],
-            cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(),
+            seqused_k=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(),
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=1,
             softmax_scale=self.softmax_scale,
-            causal=True,
+            is_causal=True,
             window_size=(-1, -1),
             softcap=0.0,
             k_descale=k_descale,
             v_descale=v_descale,
-            return_softmax_lse=False,
             mtp_step=self.mtp_step,
         )
         return o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank)

test/benchmark/kernel/benchmark_fa3_decode_mtp.py

Lines changed: 5 additions & 6 deletions

@@ -113,21 +113,20 @@ def run_fa3_mla_mtp(
     def flash_mla_fa3():
         out = flash_attn_with_kvcache_mtp(
             q=q_pe.view(-1, BLOCK_H, dpe),
-            k_cache=blocked_k_pe,
-            v_cache=blocked_k_nope,
-            qv=q_nope.view(-1, BLOCK_H, dv),
+            k=blocked_k_pe,
+            v=blocked_k_nope,
+            q_v=q_nope.view(-1, BLOCK_H, dv),
             page_table=block_table,
-            cache_seqlens=cache_seqlens,
+            seqused_k=cache_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k_new=cu_seqlens_k,
             max_seqlen_q=1,
             softmax_scale=scale,
-            causal=True,
+            is_causal=True,
             window_size=(-1, -1),
             softcap=0.0,
             k_descale=k_descale,
             v_descale=v_descale,
-            return_softmax_lse=False,
             mtp_step=1,
         )
         return out.view([b, s_q, h_q, dv])
