
Commit b2183e6

Author: sangchengmeng (committed)
Commit message: [feat]fa support cu_seqlens
1 parent 5e62f98 commit b2183e6

File tree

5 files changed: +95 additions, -88 deletions


lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 3 additions & 5 deletions

@@ -23,6 +23,7 @@
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
@@ -149,12 +150,9 @@ def forward(
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
-        q = q.unsqueeze(0)
-        k = k.unsqueeze(0)
-        v = v.unsqueeze(0)
-
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = torch.empty_like(q)
+        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
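
The call pattern above packs all image tokens into one (total_len, num_heads, head_dim) tensor and marks the per-image boundaries with cu_seqlens. Below is a minimal sketch of driving flash_attention_fwd this way, assuming a CUDA device and invented sizes; the output buffer here is a plain torch.empty_like stand-in for g_cache_manager.alloc_tensor.

import torch
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd

# Hypothetical example: three images contributing 64, 100 and 36 vision tokens.
seq_lens = torch.tensor([64, 100, 36], dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)    # [0, 64, 164, 200]
max_seqlen = int(seq_lens.max().item())           # 100

num_heads, head_dim = 16, 80                      # illustrative sizes
total_len = int(cu_seqlens[-1].item())            # 200 packed tokens
q = torch.randn(total_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
attn_output = torch.empty_like(q)                 # stand-in for g_cache_manager.alloc_tensor

flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
print(attn_output.shape)                          # torch.Size([200, 16, 80])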

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 6 additions & 2 deletions

@@ -44,6 +44,7 @@
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 from transformers.utils import is_flash_attn_2_available
 
@@ -224,10 +225,13 @@ def forward(
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
         q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb)
         k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb)
-        v = v.unsqueeze(0)
+        q = q.squeeze(0)
+        k = k.squeeze(0)
 
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        attn_output = torch.empty_like(q, dtype=q.dtype, device=q.device)
+        attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
+
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
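
The shape bookkeeping is the main change in this file: apply_rotary_pos_emb_vision still expects a leading batch dimension, so q and k take an unsqueeze(0)/squeeze(0) round trip, while v stays in the packed 3-D layout that flash_attention_fwd now consumes. A shape-only sketch with a stand-in rotary helper and invented sizes:

import torch

def rotary_stub(x, rotary_pos_emb):
    # Stand-in for apply_rotary_pos_emb_vision (illustration only): it consumes a
    # tensor with a leading batch dimension and returns one of the same shape.
    return x

seq_length, num_heads, head_dim = 200, 16, 80          # illustrative sizes
q = torch.randn(seq_length, num_heads, head_dim)
k = torch.randn(seq_length, num_heads, head_dim)
v = torch.randn(seq_length, num_heads, head_dim)

q = rotary_stub(q.unsqueeze(0), None).squeeze(0)       # (1, seq, heads, dim) -> back to packed 3-D
k = rotary_stub(k.unsqueeze(0), None).squeeze(0)
# v is no longer unsqueezed: flash_attention_fwd consumes packed (seq, heads, dim) directly.
assert q.shape == k.shape == v.shape == (seq_length, num_heads, head_dim)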

lightllm/models/vit/layer_infer/transformer_layer_infer.py

Lines changed: 7 additions & 3 deletions

@@ -103,9 +103,13 @@ def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
         out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
-        batch_size = q.shape[0]
-        seq_len = q.shape[1]
-        flash_attention_fwd(q, k, v, out)
+        batch_size, seq_len, head_num, head_dim = q.shape
+        total_len = batch_size * seq_len
+        reshape = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, out = map(reshape, (q, k, v, out))
+        cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)
         return out.reshape(batch_size, seq_len, -1)
 
     def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
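
The ViT path is fixed-length, so the prefix sums reduce to multiples of seq_len and the 4-D activations can be viewed as packed 3-D tensors without copying. A small sketch with invented sizes showing the cu_seqlens this code builds:

import torch

# Hypothetical equal-length ViT batch: every image contributes seq_len tokens.
batch_size, seq_len, head_num, head_dim = 3, 4, 2, 8

cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32) * seq_len
print(cu_seqlens.tolist())   # [0, 4, 8, 12]: sequence i occupies packed rows [cu[i], cu[i+1])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # 4

# Flattening (batch, seq, head, dim) to the packed (total, head, dim) layout is a pure view:
q = torch.randn(batch_size, seq_len, head_num, head_dim)
q_packed = q.view(batch_size * seq_len, head_num, head_dim)
s, e = cu_seqlens[1].item(), cu_seqlens[2].item()
assert torch.equal(q_packed[s:e], q[1])   # rows 4..8 are exactly image 1's tokens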

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 11 additions & 78 deletions

@@ -2,7 +2,9 @@
 import triton
 import triton.language as tl
 import math
+import time
 import torch.nn.functional as F
+from typing import Optional, Tuple
 from lightllm.utils.device_utils import is_hopper
 
 if triton.__version__ >= "2.1.0":
@@ -82,9 +84,7 @@ def _fwd_kernel(
             + cur_head * v_stride_h
             + offs_d[None, :] * v_stride_d
         )
-        v = tl.load(V + off_v, mask=((start_n + offs_n[:, None]) < seq_len) & mask_d[None, :], other=0.0).to(
-            tl.float32
-        )
+        v = tl.load(V + off_v, mask=((start_n + offs_n[:, None]) < seq_len) & mask_d[None, :], other=0.0)
         p = p.to(v.dtype)
         acc += tl.dot(p, v)
         # update m_i and l_i
@@ -106,36 +106,17 @@ def _flash_attention_triton_fwd(
     k,
     v,
     o,
-    cu_seqlens=None,  # q k v cu_seqlens,
-    max_seqlen=None,
+    cu_seqlens,  # q k v cu_seqlens,
+    max_seqlen,
 ):
     BLOCK = 64
     # shape constraints
-    assert q.shape == k.shape == v.shape == o.shape, "q, k, v, o must have the same shape"
-
-    if q.ndim == 4:
-        bs, seq_len, head_num, head_dim = q.shape
-        total_len = bs * seq_len
-        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
-        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
-    elif q.ndim == 3:
-        total_len, head_num, head_dim = q.shape
-    else:
-        raise ValueError("q,k,v,o must be 3d or 4d")
-
-    if cu_seqlens is None:  # fixed-length case
-        cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=q.device) * seq_len
-    else:
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-
-    if max_seqlen is None:
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-
+    assert q.ndim == k.ndim == v.ndim == o.ndim == 3, "q, k, v, o must be 3D tensors"
+    _, head_num, head_dim = q.shape
     batch_size = cu_seqlens.numel() - 1
 
-    d_pad = triton.next_power_of_2(head_dim)
     sm_scale = 1.0 / (head_dim ** 0.5)  # compute the scale factor
-
+    d_pad = triton.next_power_of_2(head_dim)
     grid = (triton.cdiv(max_seqlen, BLOCK), head_num, batch_size)  # batch, head,
     num_warps = 4
     _fwd_kernel[grid](
@@ -180,18 +161,11 @@ def flash_attention_v3_fwd(
     k,
     v,
     o,
-    cu_seqlens=None,
-    max_seqlen=None,
+    cu_seqlens,
+    max_seqlen,
 ):
     head_dim = q.shape[-1]
     softmax_scale = head_dim ** -0.5
-    if cu_seqlens is not None:
-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-    if q.ndim == 4:
-        bs, seq_len, head_num, head_dim = q.shape
-        total_len = bs * seq_len
-        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
-        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
     _flash_attn_forward(
         q,
         k,
@@ -229,7 +203,7 @@ def flash_attention_v3_fwd(
     _flash_attn_v3_available = False
 
 
-def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlen=None):
+def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
     """
     Unified Flash Attention interface: if _flash_attn_forward is available,
     use flash_attention_v3_fwd; otherwise fall back to the Triton version.
@@ -238,44 +212,3 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlen=None):
         flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
         _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)
-
-
-def torch_att(q, k, v):
-    head_dim = q.shape[-1]
-    q = q.transpose(1, 2)
-    k = k.transpose(1, 2)
-    v = v.transpose(1, 2)
-    scale = head_dim ** -0.5
-    attn = (q * scale) @ k.transpose(-2, -1)
-    attn = attn.softmax(dim=-1)
-    out = attn @ v
-    out = out.transpose(1, 2).contiguous()
-    return out
-
-
-def test():
-    import torch
-    import numpy as np
-
-    B, L, H, D = 4, 1025, 7, 128
-    dtype = torch.float16
-    q = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    k = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    v = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    o = torch.empty((B, L, H, D), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    torch_out = torch_att(q, k, v)
-    import time
-
-    torch.cuda.synchronize()
-    a = time.time()
-    for i in range(100):
-        flash_attention_fwd(q, k, v, o)
-        # o = torch_att(q, k, v)
-    torch.cuda.synchronize()
-    b = time.time()
-    # print(o.shape, torch_out.shape)
-    print((b - a) / 100 * 1000)
-
-    print("max ", torch.max(torch.abs(torch_out - o)))
-    print("mean ", torch.mean(torch.abs(torch_out - o)))
-    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
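
With this change the low-level entry points no longer accept 4-D padded inputs or optional cu_seqlens/max_seqlen; that handling now lives at the call sites (see the ViT layer above). A caller that still has equal-length 4-D tensors can reproduce the removed behavior with a small wrapper. A hypothetical sketch, not part of the commit:

import torch
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd

def flash_attention_fwd_padded(q, k, v, o):
    """Hypothetical convenience wrapper for equal-length 4-D inputs.

    It reproduces, at the call site, the reshaping and cu_seqlens construction
    that this commit removed from _flash_attention_triton_fwd: contiguous
    (bs, seq, head, dim) tensors are viewed as packed (bs * seq, head, dim)
    and a uniform cu_seqlens is derived from the fixed sequence length.
    """
    assert q.ndim == 4 and q.shape == k.shape == v.shape == o.shape
    bs, seq_len, head_num, head_dim = q.shape
    total_len = bs * seq_len
    q, k, v, o3 = (t.view(total_len, head_num, head_dim) for t in (q, k, v, o))
    cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=q.device) * seq_len
    max_seqlen = seq_len
    flash_attention_fwd(q, k, v, o3, cu_seqlens, max_seqlen)
    # o3 is a view of o, so the results are also visible in the caller's 4-D buffer.
    return o
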
Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+import torch
+import triton
+import triton.language as tl
+import math
+import time
+import torch.nn.functional as F
+from typing import Optional, Tuple
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+
+
+def reference_attention_varlen(q, k, v, cu):
+    """
+    q, k, v : (total_len, n_head, D)
+    cu_seqlen : prefix sums (batch+1,)
+    """
+    total, n_head, d = q.shape
+    out = torch.empty_like(q)
+    scale = 1.0 / math.sqrt(d)
+
+    for b in range(cu.numel() - 1):
+        s, e = cu[b].item(), cu[b + 1].item()
+        q_b, k_b, v_b = q[s:e], k[s:e], v[s:e]  # (seq, head, D)
+
+        q_hsd = q_b.permute(1, 0, 2)  # (head, seq, D)
+        k_hds = k_b.permute(1, 2, 0)  # (head, D, seq)
+        v_hsd = v_b.permute(1, 0, 2)  # (head, seq, D)
+
+        scores = torch.matmul(q_hsd, k_hds) * scale  # (head, seq, seq)
+        probs = torch.softmax(scores.float(), dim=-1)
+
+        out_hsd = torch.matmul(probs, v_hsd.float())  # (head, seq, D)
+        out[s:e] = out_hsd.permute(1, 0, 2).to(q.dtype)  # back to (seq, head, D)
+
+    return out
+
+
+def test_varlen(batch=4, heads=8, d=80, dtype=torch.bfloat16, atol=1e-2, device="cuda:0"):
+    torch.manual_seed(0)
+    lengths = torch.randint(1, 257, (batch,))
+    max_len = int(lengths.max().item())
+
+    cu = torch.zeros(batch + 1, dtype=torch.int32, device=device)
+    cu[1:] = torch.cumsum(lengths, 0)
+    tot = int(cu[-1])
+
+    q = torch.randn(tot, heads, d, dtype=dtype, device=device)
+    k = torch.randn_like(q)
+    v = torch.randn_like(q)
+    out_tri = torch.randn_like(q)
+    flash_attention_fwd(q, k, v, out_tri, cu, max_len)
+    a = time.time()
+    for _ in range(1000):
+        flash_attention_fwd(q, k, v, out_tri, cu, max_len)
+    b = time.time()
+    print(f"flash_attention_fwd time: {(b - a) / 1000 * 1000:.2f} ms")
+    out_ref = reference_attention_varlen(q, k, v, cu)
+
+    max_err = (out_ref - out_tri).abs().max().item()
+    mean_err = (out_ref - out_tri).abs().mean().item()
+    print(f"{dtype}: max {max_err:.6f}, mean {mean_err:.6f}")
+    torch.testing.assert_close(out_tri, out_ref, atol=atol, rtol=0)
+
+
+if __name__ == "__main__":
+    tests = [(torch.float16, 1e-2), (torch.bfloat16, 2e-2)]
+    for dt, tol in tests:
+        test_varlen(dtype=dt, atol=tol)
+    print("✓ variable-length Flash-Attention all dtypes pass")
