Commit 5e62f98

Author: sangchengmeng (committed)
Commit message: [feat]fa support cu_seqlens
1 parent 2e7b02f · commit 5e62f98

File tree

1 file changed: +44, -54 lines


lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 44 additions & 54 deletions
@@ -13,26 +13,20 @@ def _fwd_kernel(
     K,
     V,
     sm_scale,
-    seq_len,
     Out,
-    q_stride_b,
     q_stride_s,
     q_stride_h,
     q_stride_d,
-    k_stride_b,
     k_stride_s,
     k_stride_h,
     k_stride_d,
-    v_stride_b,
     v_stride_s,
     v_stride_h,
     v_stride_d,
-    o_stride_b,
     o_stride_s,
     o_stride_h,
     o_stride_d,
     head_dim_act,
-    is_varlen: tl.constexpr,
     cu_seqlens,
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -42,29 +36,17 @@ def _fwd_kernel(
     cur_head = tl.program_id(1)
     start_m = tl.program_id(0)

-    if is_varlen == 1:
-        seq_start = tl.load(cu_seqlens + cur_batch).to(tl.int32)
-        seq_end = tl.load(cu_seqlens + cur_batch + 1).to(tl.int32)
-        seq_len = seq_end - seq_start
-        q_stride_b = 0
-        k_stride_b = 0
-        v_stride_b = 0
-        o_stride_b = 0
-    else:
-        seq_start = 0
+    seq_start = tl.load(cu_seqlens + cur_batch).to(tl.int32)
+    seq_end = tl.load(cu_seqlens + cur_batch + 1).to(tl.int32)
+    seq_len = seq_end - seq_start

     # initialize offsets
     offs_n = tl.arange(0, BLOCK_N)
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

     mask_d = offs_d < head_dim_act
-    off_q = (
-        cur_batch * q_stride_b
-        + cur_head * q_stride_h
-        + (seq_start + offs_m[:, None]) * q_stride_s
-        + offs_d[None, :] * q_stride_d
-    )
+    off_q = cur_head * q_stride_h + (seq_start + offs_m[:, None]) * q_stride_s + offs_d[None, :] * q_stride_d
     q = tl.load(Q + off_q, mask=(offs_m[:, None] < seq_len) & mask_d[None, :], other=0.0)
     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
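
Note: cu_seqlens is the cumulative (exclusive prefix-sum) token count over the batch, so each kernel program recovers its sequence boundaries with the two loads shown above. A minimal host-side sketch of that convention (the tensor values below are illustrative, not taken from the repo):

import torch

# three variable-length sequences packed back to back (lengths are made up)
seq_lens = torch.tensor([196, 576, 64], dtype=torch.int32)

# cu_seqlens = [0, 196, 772, 836]; sequence i occupies rows [cu_seqlens[i], cu_seqlens[i + 1])
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)

cur_batch = 1
seq_start = int(cu_seqlens[cur_batch])       # 196
seq_end = int(cu_seqlens[cur_batch + 1])     # 772
seq_len = seq_end - seq_start                # 576, the same value the kernel derives per program
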
@@ -75,15 +57,14 @@ def _fwd_kernel(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         off_k = (
-            cur_batch * k_stride_b
-            + (seq_start + start_n + offs_n[None, :]) * k_stride_s
+            (seq_start + start_n + offs_n[None, :]) * k_stride_s
             + cur_head * k_stride_h
             + offs_d[:, None] * k_stride_d
         )
         k = tl.load(K + off_k, mask=((start_n + offs_n[None, :]) < seq_len) & mask_d[:, None], other=0.0)

         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k, out_dtype=tl.float32, allow_tf32=False)
+        qk += tl.dot(q, k)
         qk *= sm_scale
         qk += tl.where((start_n + offs_n[None, :]) < seq_len, 0, float("-inf"))

@@ -97,8 +78,7 @@ def _fwd_kernel(

         # update acc
         off_v = (
-            cur_batch * v_stride_b
-            + (seq_start + start_n + offs_n[:, None]) * v_stride_s
+            (seq_start + start_n + offs_n[:, None]) * v_stride_s
             + cur_head * v_stride_h
             + offs_d[None, :] * v_stride_d
         )
@@ -115,12 +95,7 @@ def _fwd_kernel(
     o_scale = tl.exp(m_i - l_i)
     acc = acc * o_scale[:, None]
     # initialize pointers to output
-    off_o = (
-        cur_batch * o_stride_b
-        + (seq_start + offs_m[:, None]) * o_stride_s
-        + cur_head * o_stride_h
-        + offs_d[None, :] * o_stride_d
-    )
+    off_o = (seq_start + offs_m[:, None]) * o_stride_s + cur_head * o_stride_h + offs_d[None, :] * o_stride_d
     out_ptrs = Out + off_o
     tl.store(out_ptrs, acc, mask=(offs_m[:, None] < seq_len) & mask_d[None, :])
     return
@@ -132,49 +107,57 @@ def _flash_attention_triton_fwd(
     v,
     o,
     cu_seqlens=None,  # q k v cu_seqlens,
-    max_seqlens=None,
+    max_seqlen=None,
 ):
     BLOCK = 64
     # shape constraints
+    assert q.shape == k.shape == v.shape == o.shape, "q, k, v, o must have the same shape"

-    batch_size, seq_len, head_num, head_dim = q.shape
-    if cu_seqlens is not None and max_seqlens is not None:
-        assert q.shape[0] == 1
+    if q.ndim == 4:
+        bs, seq_len, head_num, head_dim = q.shape
+        total_len = bs * seq_len
+        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
+    elif q.ndim == 3:
+        total_len, head_num, head_dim = q.shape
+    else:
+        raise ValueError("q,k,v,o must be 3d or 4d")
+
+    if cu_seqlens is None:  # fixed-length inputs
+        cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=q.device) * seq_len
+    else:
         cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        seq_len = max_seqlens
-        batch_size = cu_seqlens.numel() - 1
+
+    if max_seqlen is None:
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+
+    batch_size = cu_seqlens.numel() - 1

     d_pad = triton.next_power_of_2(head_dim)
     sm_scale = 1.0 / (head_dim ** 0.5)  # compute the softmax scale factor
-    # grid = (batch_size, head_num, triton.cdiv(seq_len, BLOCK))  # batch, head,
-    grid = (triton.cdiv(seq_len, BLOCK), head_num, batch_size)  # batch, head,
+
+    grid = (triton.cdiv(max_seqlen, BLOCK), head_num, batch_size)  # batch, head,
     num_warps = 4
     _fwd_kernel[grid](
         q,
         k,
         v,
         sm_scale,
-        seq_len,
         o,
         q.stride(0),
         q.stride(1),
         q.stride(2),
-        q.stride(3),
         k.stride(0),
         k.stride(1),
         k.stride(2),
-        k.stride(3),
         v.stride(0),
         v.stride(1),
         v.stride(2),
-        v.stride(3),
         o.stride(0),
         o.stride(1),
         o.stride(2),
-        o.stride(3),
         head_dim,
-        is_varlen=1 if cu_seqlens is not None else 0,
-        cu_seqlens=0 if cu_seqlens is None else cu_seqlens,
+        cu_seqlens,
         BLOCK_M=BLOCK,
         BLOCK_DMODEL=d_pad,
         BLOCK_N=BLOCK,
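
Note: for the padded, fixed-length case the rewritten wrapper still accepts the old 4D layout; it flattens q/k/v/o to (bs * seq_len, head_num, head_dim) and synthesizes an evenly spaced cu_seqlens itself. A minimal sketch of that call path (the shapes below are illustrative, not from the repo):

import torch

# illustrative shapes: 2 images, 196 patch tokens each, 16 heads of width 64
bs, seq_len, head_num, head_dim = 2, 196, 16, 64
q = torch.randn(bs, seq_len, head_num, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

# no cu_seqlens passed: the wrapper builds torch.arange(bs + 1) * seq_len = [0, 196, 392]
# and launches the kernel over max_seqlen = 196
_flash_attention_triton_fwd(q, k, v, o)
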
@@ -198,10 +181,17 @@ def flash_attention_v3_fwd(
     v,
     o,
     cu_seqlens=None,
-    max_seqlens=None,
+    max_seqlen=None,
 ):
     head_dim = q.shape[-1]
     softmax_scale = head_dim ** -0.5
+    if cu_seqlens is not None:
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
+    if q.ndim == 4:
+        bs, seq_len, head_num, head_dim = q.shape
+        total_len = bs * seq_len
+        reshape_fn = lambda t: t.view(total_len, head_num, head_dim)
+        q, k, v, o = [reshape_fn(x) for x in (q, k, v, o)]
     _flash_attn_forward(
         q,
         k,
@@ -214,8 +204,8 @@ def flash_attention_v3_fwd(
         None,  # cu_seqlens_q/k/k_new
         None,
         None,  # seqused_q/k
-        max_seqlens,
-        max_seqlens,  # max_seqlen_q/k
+        max_seqlen,
+        max_seqlen,  # max_seqlen_q/k
         None,
         None,
         None,  # page_table, kv_batch_idx, leftpad_k,
@@ -239,15 +229,15 @@ def flash_attention_v3_fwd(
     _flash_attn_v3_available = False


-def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlens=None):
+def flash_attention_fwd(q, k, v, o, cu_seqlens=None, max_seqlen=None):
     """
     Unified Flash Attention interface. If _flash_attn_forward is available,
     use flash_attention_v3_fwd; otherwise fall back to the Triton version.
     """
     if _flash_attn_v3_available and is_hopper():
-        flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlens)
+        flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
     else:
-        _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlens)
+        _flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)


 def torch_att(q, k, v):
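
Note: the variable-length case goes through the same unified entry point; whichever backend is chosen, the caller supplies packed 3D tensors plus cu_seqlens. A hedged usage sketch (sequence lengths and head sizes are made up for illustration):

import torch

# two variable-length sequences packed back to back (lengths are illustrative)
seq_lens = torch.tensor([196, 80], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0)).to(torch.int32)
total_len = int(seq_lens.sum())

head_num, head_dim = 16, 64
q = torch.randn(total_len, head_num, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

# dispatches to flash_attention_v3_fwd on Hopper when FA3 is installed,
# otherwise to the Triton kernel; both share the cu_seqlens/max_seqlen signature
flash_attention_fwd(q, k, v, o, cu_seqlens=cu_seqlens, max_seqlen=int(seq_lens.max()))
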
