Commit 1aa427f

[Kernels] Add Flash Linear Attention Kernels (vllm-project#24518)
Signed-off-by: youkaichao <[email protected]>
1 parent 1c63a16 commit 1aa427f

17 files changed, +2671 -2 lines changed


tools/mypy.sh

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ run_mypy vllm/engine
 run_mypy vllm/executor
 run_mypy vllm/inputs
 run_mypy vllm/lora
-run_mypy vllm/model_executor
+run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor
 run_mypy vllm/plugins
 run_mypy vllm/worker
 run_mypy vllm/v1
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .layernorm_guard import RMSNormGated

__all__ = [
    "RMSNormGated",
    "chunk_gated_delta_rule",
    "fused_recurrent_gated_delta_rule",
]
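
For reference, a minimal import sketch of the surface these re-exports create. The package path is an assumption inferred from the mypy exclusion above ('vllm/model_executor/layers/fla/ops'); it is not restated in this hunk.

# Assumed import path, mirroring the mypy exclude 'vllm/model_executor/layers/fla/ops'.
from vllm.model_executor.layers.fla.ops import (
    RMSNormGated,
    chunk_gated_delta_rule,
    fused_recurrent_gated_delta_rule,
)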
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional

import torch
from einops import rearrange

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor,
                               beta: torch.Tensor,
                               scale: float,
                               initial_state: torch.Tensor,
                               output_final_state: bool,
                               cu_seqlens: Optional[torch.LongTensor] = None):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(k=k,
                                 beta=beta,
                                 g_cumsum=g,
                                 cu_seqlens=cu_seqlens,
                                 output_dtype=torch.float32)
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    elif SUPPRESS_LEVEL >= 3:
        return g, o, A, final_state, w, h, v_new

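# chunk_gated_delta_rule_fwd above chains the chunked forward pipeline: a
# chunk-local cumulative sum of the log-space gates (chunk_local_cumsum), a
# beta-scaled K.K^T per chunk (chunk_scaled_dot_kkt_fwd), a triangular solve
# (solve_tril) feeding the WY representation (recompute_w_u_fwd), the
# chunkwise state recurrence (chunk_gated_delta_rule_fwd_h), and the final
# output computation against q (chunk_fwd_o).
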
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @torch.amp.custom_fwd(device_type='cuda')
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                output_final_state: bool,
                cu_seqlens: Optional[torch.LongTensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state

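# ChunkGatedDeltaRuleFunction above defines only forward() and no backward(),
# so as copied in this file the op does not support gradient computation.
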
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           g: torch.Tensor,
                           beta: torch.Tensor,
                           scale: float = None,
                           initial_state: torch.Tensor = None,
                           output_final_state: bool = False,
                           cu_seqlens: Optional[torch.LongTensor] = None,
                           head_first: bool = False,
                           use_qk_l2norm_in_kernel: bool = False):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert len(
        beta.shape
    ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            stacklevel=2)
        q, k, v, beta, g = map(
            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
            (q, k, v, beta, g))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing.")
        if initial_state is not None and initial_state.shape[0] != len(
                cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1]**-0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
        use_qk_l2norm_in_kernel)
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state

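To make the docstring's variable-length recipe concrete, here is a short, self-contained usage sketch of chunk_gated_delta_rule with cu_seqlens. It is an illustration only: the import path is an assumption inferred from the mypy exclusion above, and it assumes a CUDA device with bfloat16 inputs (float32 is rejected by the assertions in this file).

# Usage sketch (assumptions: CUDA device, bfloat16 inputs, and the import path
# vllm.model_executor.layers.fla.ops inferred from the mypy exclude above).
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule

H, K, V = 4, 128, 128                 # heads, key dim, value dim
seq_lens = [512, 1024, 256]           # three sequences packed into one batch row
T = sum(seq_lens)                     # 1792 tokens after flattening
# Cumulative boundaries of the packed sequences, FlashAttention-style.
cu_seqlens = torch.tensor([0, 512, 1536, 1792], dtype=torch.long, device='cuda')

# Batch size must be 1 when cu_seqlens is given (sequences are concatenated along T).
q = torch.randn(1, T, H, K, dtype=torch.bfloat16, device='cuda')
k = F.normalize(torch.randn(1, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
v = torch.randn(1, T, H, V, dtype=torch.bfloat16, device='cuda')
beta = torch.rand(1, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
g = F.logsigmoid(torch.rand(1, T, H, dtype=torch.bfloat16, device='cuda'))
# One initial state per packed sequence: [N, H, K, V] with N = len(cu_seqlens) - 1.
h0 = torch.randn(3, H, K, V, dtype=torch.bfloat16, device='cuda')

o, ht = chunk_gated_delta_rule(
    q, k, v, g, beta,
    initial_state=h0,
    output_final_state=True,
    cu_seqlens=cu_seqlens,
)
print(o.shape)   # torch.Size([1, 1792, 4, 128])
print(ht.shape)  # torch.Size([3, 4, 128, 128])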