-# This file is adapted from sgl-project/sglang:
-# https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/flash_attn.py
-# The original code and this file are licensed under the Apache License, Version 2.0.
-#
-# Copyright (c) sgl-project and other contributors.
-# Modifications Copyright (c) LightLLM contributors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 import torch
 from typing import List, Optional, Tuple, Union
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)


-def maybe_contiguous(x):
+def get_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x


@@ -34,152 +16,61 @@ def maybe_contiguous(x):

     def flash_attn_with_kvcache_mtp(
         q,
-        k_cache,
-        v_cache,
-        k=None,
-        v=None,
-        qv=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
-        cache_batch_idx: Optional[torch.Tensor] = None,
-        cache_leftpad: Optional[torch.Tensor] = None,
-        page_table: Optional[torch.Tensor] = None,
+        k,
+        v,
+        k_new: Optional[torch.Tensor] = None,
+        v_new: Optional[torch.Tensor] = None,
+        q_v: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
         cu_seqlens_k_new: Optional[torch.Tensor] = None,
+        seqused_q: Optional[torch.Tensor] = None,
+        seqused_k: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
+        max_seqlen_k: Optional[int] = None,
+        page_table: Optional[torch.Tensor] = None,
+        cache_batch_idx: Optional[torch.Tensor] = None,
+        cache_leftpad: Optional[torch.Tensor] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
         rotary_seqlens: Optional[torch.Tensor] = None,
         q_descale: Optional[torch.Tensor] = None,
         k_descale: Optional[torch.Tensor] = None,
         v_descale: Optional[torch.Tensor] = None,
         softmax_scale=None,
-        causal=False,
-        window_size=(-1, -1),  # -1 means infinite context window
+        is_causal=False,
+        window_size=(-1, -1),
         softcap=0.0,  # 0.0 means deactivated
-        rotary_interleaved=True,
+        is_rotary_interleaved=True,
         scheduler_metadata=None,
-        num_splits=0,  # Can be tuned for speed
-        pack_gqa=None,  # Can be tuned for speed
-        sm_margin=0,  # Can be tuned if some SMs are used for communication
-        return_softmax_lse=False,
+        num_splits=0,
+        pack_gqa=None,
+        sm_margin=0,
         mtp_step=0,
     ):
67- """
68- If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
69- k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
70- the previous step, and update them with the new keys/values from the current step, and do
71- attention with the updated cache, all in 1 kernel.
72-
73- If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
74- For example, the KV cache could be pre-allocated with the max sequence length, and you can use
75- cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
76-
77- Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
78- rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
79- If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
80- and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
81- If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
82- indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
83-
84- See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
85-
86- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
87- than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
88- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
89- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
90-
91- If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
92- For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
93- 1 1 1 1 0
94- 1 1 1 1 1
95- If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
96- 0 0
97- 0 0
98- 0 0
99- 1 0
100- 1 1
101- If the row of the mask is all zero, the output will be zero.
102-
103- If window_size != (-1, -1), implements sliding window local attention. Query at position i
104- will only attend to keys between
105- [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
106-
107- Note: Does not support backward pass.
108-
109- Arguments:
110- q: (batch_size, seqlen, nheads, headdim)
111- k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
112- or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
113- page_block_size must be a multiple of 256.
114- v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
115- or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
116- k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
117- k with k_cache, starting at the indices specified by cache_seqlens.
118- v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
119- qv [optional]: (batch_size, seqlen, nheads, headdim_v)
120- rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
121- to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
122- rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
123- cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
124- KV cache.
125- cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
126- If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
127- If the indices are not distinct, and k and v are provided, the values updated in the cache
128- might come from any of the duplicate indices.
129- cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
130- page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
131- softmax_scale: float. The scaling of QK^T before applying softmax.
132- Default to 1 / sqrt(headdim).
133- causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
134- window_size: (left, right). If not (-1, -1), implements sliding window local attention.
135- softcap: float. Anything > 0 activates softcapping attention.
136- rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
137- If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
138- rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
139- (i.e. GPT-NeoX style).
140- num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
141- If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
142- to automatically determine the number of splits.
143- Don't change this unless you know what you are doing.
144- return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
145-
146- Return:
147- out: (batch_size, seqlen, nheads, headdim).
148- softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
149- logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
150- normalization factor).
151- """
-        assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
-        assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+        assert k.stride(-1) == 1, "k must have contiguous last dimension"
+        assert v.stride(-1) == 1, "v must have contiguous last dimension"
         if softmax_scale is None:
-            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-        if cache_seqlens is not None and isinstance(cache_seqlens, int):
-            cache_seqlens = torch.full((k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device)
-        cache_seqlens = maybe_contiguous(cache_seqlens)
-
-        q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
-        v_cache = v_cache.contiguous() if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 else v_cache
-        cu_seqlens_q, cu_seqlens_k_new = [maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
-        page_table, cache_batch_idx, cache_leftpad = [
-            maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
-        ]
-        rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
-        rotary_seqlens = maybe_contiguous(rotary_seqlens)
+            softmax_scale = (q.shape[-1] + (q_v.shape[-1] if q_v is not None else 0)) ** (-0.5)
+        seqused_k = get_contiguous(seqused_k)

-        # out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
+        q, k, k_new, v_new = [get_contiguous(x) for x in (q, k, k_new, v_new)]
+        v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
+        cu_seqlens_q, cu_seqlens_k_new = [get_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
+        page_table = get_contiguous(page_table)
         out, softmax_lse, *rest = flash_attn_3_mtp.fwd(
             q,
-            k_cache,
-            v_cache,
             k,
             v,
-            qv,
+            k_new,
+            v_new,
+            q_v,
             None,  # out
             cu_seqlens_q,
             None,  # cu_seqlens_k
             cu_seqlens_k_new,
             None,  # seqused_q
-            cache_seqlens,
+            seqused_k,
             max_seqlen_q,
             None,  # max_seqlen_k
             page_table,
@@ -192,19 +83,19 @@ def flash_attn_with_kvcache_mtp(
             k_descale,
             v_descale,
             softmax_scale,
-            causal,
+            is_causal,
             window_size[0],
             window_size[1],
             0,
             softcap,
-            rotary_interleaved,
+            is_rotary_interleaved,
             scheduler_metadata,
             num_splits,
             pack_gqa,
             sm_margin,
             mtp_step,
         )
-        return (out, softmax_lse, *rest) if return_softmax_lse else out
+        return out

 except:
     flash_attn_3_mtp = None
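
A note on the causal masking described in the docstring this commit removes: with is_causal=True the mask is aligned to the bottom-right corner of the attention matrix, so the last query token always sees the full key sequence. The snippet below is a small illustration of that rule only (it is not part of the commit, and the helper name bottom_right_causal_mask is ours):

import torch

def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Query i may attend to key j iff j <= i + seqlen_k - seqlen_q,
    # i.e. the causal diagonal is anchored at the bottom-right corner.
    i = torch.arange(seqlen_q).unsqueeze(1)
    j = torch.arange(seqlen_k).unsqueeze(0)
    return (j <= i + seqlen_k - seqlen_q).int()

print(bottom_right_causal_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
print(bottom_right_causal_mask(5, 2))
# tensor([[0, 0],
#         [0, 0],
#         [0, 0],
#         [1, 0],
#         [1, 1]])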
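For reference, here is a minimal usage sketch of the renamed entry point after this change. It assumes the flash_attn_3_mtp extension is importable (otherwise the module sets it to None) and a paged KV cache; the import path is left hypothetical, and the shapes, block counts, and sequence lengths are illustrative only, based on the docstring that this commit removes:

import torch
# from <lightllm module containing this file> import flash_attn_with_kvcache_mtp  # hypothetical import path

batch, seqlen_q, nheads, nheads_k, headdim = 2, 1, 32, 4, 128
num_blocks, page_block_size = 64, 256  # page_block_size must be a multiple of 256

q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(num_blocks, page_block_size, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)

# page_table maps each request to its cache blocks; seqused_k (the old cache_seqlens)
# holds the current KV length of each request.
page_table = torch.arange(batch * 2, dtype=torch.int32, device="cuda").reshape(batch, 2)
seqused_k = torch.tensor([300, 450], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache_mtp(
    q,
    k_cache,  # the `k` argument after the rename
    v_cache,  # the `v` argument after the rename
    page_table=page_table,
    seqused_k=seqused_k,
    max_seqlen_q=seqlen_q,
    is_causal=True,
    mtp_step=0,  # multi-token-prediction step count; left at the default here
)
# out: (batch, seqlen_q, nheads, headdim). Only `out` is returned now that the
# return_softmax_lse path has been dropped.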