Commit 9625d82

Add FA3 (#3623)
* add fa3
* fix cu_seqlens_k on chat mode
* add fa3 for qwen3
* fix fa3
* replace bhsd with hsd
1 parent b5ceeed commit 9625d82

File tree

3 files changed: +237 -39 lines changed
lmdeploy/pytorch/backends/cuda/attention.py

Lines changed: 199 additions & 22 deletions
@@ -6,9 +6,23 @@
 import torch

 from lmdeploy.pytorch.distributed import get_tp_world_rank
+from lmdeploy.utils import get_logger

 from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata

+logger = get_logger('lmdeploy')
+
+use_fa3 = False
+try:
+    # Now flash-attention only support FA3 for sm90a && cuda >= 12.3
+    if (torch.cuda.get_device_capability()[0] == 9) and (torch.version.cuda >= '12.3'):
+        import flash_attn_interface  # noqa: F401
+        assert torch.ops.flash_attn_3 is not None
+        use_fa3 = True
+except Exception:
+    logger.warning('For higher performance, please install FlashAttention-3 '
+                   'https://github.com/Dao-AILab/flash-attention')
+


 @dataclass
 class TritonAttentionMetadata(AttentionMetadata):
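Note: the gate above only flips `use_fa3` when the GPU reports compute capability 9 (sm90a/Hopper), the CUDA runtime is at least 12.3, and `flash_attn_interface` imports cleanly. A minimal, standalone sketch (not part of the commit) for checking the same conditions on a local machine:

import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on Hopper
cuda_version = torch.version.cuda                   # e.g. '12.4'; None on CPU-only builds
print(f'device: sm{major}{minor}, CUDA {cuda_version}')
# Mirrors the string-based version comparison used in the gate above.
print('FA3 eligible:', major == 9 and cuda_version is not None and cuda_version >= '12.3')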
@@ -25,6 +39,8 @@ class TritonAttentionMetadata(AttentionMetadata):
     # flash mla
     tile_scheduler_metadata: torch.Tensor = None
     num_splits: torch.Tensor = None
+    cu_seqlens_q: torch.Tensor = None
+    cu_seqlens_k: torch.Tensor = None


 def _cdiv(a, b):
@@ -89,7 +105,6 @@ def forward(
         inplace: bool = True,
     ) -> torch.Tensor:
         """forward."""
-
         block_offsets = attn_metadata.block_offsets
         q_start_loc = attn_metadata.q_start_loc
         fill_q_start_loc = q_start_loc
@@ -129,7 +144,6 @@ def forward(
         q_shape = query.shape
         o_shape = q_shape[:-1] + (self.v_head_size, )
         attn_output = query.new_empty(o_shape)
-
         is_decoding = attn_metadata.is_decoding
         if not self.alibi:
             if is_decoding:
@@ -286,7 +300,6 @@ def forward(

         q_shape = query.shape
         o_shape = q_shape[:-1] + (self.v_head_size, )
-        attn_output = query.new_empty(o_shape)

         is_decoding = attn_metadata.is_decoding
         if is_decoding:
@@ -302,7 +315,6 @@ def forward(
                 tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
                 num_splits=attn_metadata.num_splits,
                 causal=True)
-
         else:
             BLOCK_BS = k_cache.size(1)
             # pad one more block to avoid invalid kv visit
@@ -313,26 +325,179 @@ def forward(
                 kv_seqlens,
                 block_offsets,
                 start_loc=kv_start_loc,
-                out_size=out_size,
+                out_size=kv_flatten_size if use_fa3 else out_size,
                 out_dtype=query.dtype,
                 k_scales_zeros=k_scales_zeros,
                 v_scales_zeros=v_scales_zeros,
                 quant_policy=quant_policy,
+                flatten_kv_layout='shd' if use_fa3 else 'hsd',
             )
-            self.flash_attention_fwd(
+            if use_fa3:
+                q_rope = query[:, :, self.v_head_size:]
+                q_nope = query[:, :, :self.v_head_size]
+                k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]
+                c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]
+                from flash_attn_interface import flash_attn_varlen_func
+                attn_output, _ = flash_attn_varlen_func(
+                    q=q_rope,
+                    k=k_rope,
+                    v=c_kv,
+                    qv=q_nope,
+                    cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                    cu_seqlens_k=attn_metadata.cu_seqlens_k,
+                    max_seqlen_q=max_q_seqlen,
+                    max_seqlen_k=kv_flatten_size,
+                    softmax_scale=self.scale,
+                    causal=self.causal,
+                    window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
+                    softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
+                )
+            else:
+                attn_output = query.new_empty(o_shape)
+                self.flash_attention_fwd(
+                    query,
+                    flatten_k,
+                    flatten_v,
+                    attn_output,
+                    q_start_loc=q_start_loc,
+                    q_seqlens=q_seqlens,
+                    kv_start_loc=kv_start_loc,
+                    kv_seqlens=kv_seqlens,
+                    max_seqlen=max_q_seqlen,
+                    window_size=self.sliding_window,
+                    sm_scale=self.scale,
+                    logit_softcapping=self.logit_softcapping,
+                    causal=self.causal,
+                )
+        return attn_output
+
+
+class FA3Impl(TritonAttentionImpl):
+    """Triton attention implementation."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float = None,
+        num_kv_heads: int = None,
+        v_head_size: int = None,
+        alibi: bool = False,
+        sliding_window: int = None,
+        logit_softcapping: float = None,
+        causal: bool = True,
+        **kwargs,
+    ):
+        assert alibi is False, 'alibi not supported for FA3'
+        super().__init__(
+            num_heads=num_heads,
+            head_size=head_size,
+            scale=scale,
+            num_kv_heads=num_kv_heads,
+            v_head_size=v_head_size,
+            alibi=alibi,
+            sliding_window=sliding_window,
+            logit_softcapping=logit_softcapping,
+            causal=causal,
+            **kwargs,
+        )
+        from flash_attn_interface import flash_attn_varlen_func
+        self.flash_attn_varlen_func_v3 = flash_attn_varlen_func
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        attn_metadata: TritonAttentionMetadata,
+        k_scales_zeros: torch.Tensor = None,
+        v_scales_zeros: torch.Tensor = None,
+        inplace: bool = True,
+    ) -> torch.Tensor:
+        """forward."""
+        block_offsets = attn_metadata.block_offsets
+        q_start_loc = attn_metadata.q_start_loc
+        fill_q_start_loc = q_start_loc
+        q_seqlens = attn_metadata.q_seqlens
+        fill_seqlens = q_seqlens
+        kv_start_loc = attn_metadata.kv_start_loc
+        kv_seqlens = attn_metadata.kv_seqlens
+        kv_flatten_size = attn_metadata.kv_flatten_size
+        quant_policy = attn_metadata.quant_policy
+        if attn_metadata.is_decoding:
+            max_q_seqlen = 1
+        else:
+            max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+        fill_max_q_seqlen = max_q_seqlen
+        if attn_metadata.fill_seqlens is not None:
+            fill_seqlens = attn_metadata.fill_seqlens
+            fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
+            fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
+        is_decoding = attn_metadata.is_decoding
+        # fill kv cache
+        if key is not None and value is not None:
+            self.fill_kv_cache(
+                key,
+                value,
+                k_cache,
+                v_cache,
+                fill_q_start_loc,
+                fill_seqlens,
+                kv_seq_length=kv_seqlens,
+                max_q_seq_length=fill_max_q_seqlen,
+                block_offsets=block_offsets,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+            )
+
+        q_shape = query.shape
+        o_shape = q_shape[:-1] + (self.v_head_size, )
+        attn_output = query.new_empty(o_shape)
+
+        if is_decoding:
+            self.paged_attention_fwd(
                 query,
-                flatten_k,
-                flatten_v,
+                k_cache,
+                v_cache,
                 attn_output,
-                q_start_loc=q_start_loc,
-                q_seqlens=q_seqlens,
-                kv_start_loc=kv_start_loc,
+                block_offsets,
                 kv_seqlens=kv_seqlens,
-                max_seqlen=max_q_seqlen,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
                 window_size=self.sliding_window,
                 sm_scale=self.scale,
                 logit_softcapping=self.logit_softcapping,
+            )
+        else:
+            flatten_k, flatten_v = self.flatten_kv_cache(
+                k_cache,
+                v_cache,
+                kv_seqlens,
+                block_offsets,
+                start_loc=kv_start_loc,
+                out_size=kv_flatten_size,
+                out_dtype=query.dtype,
+                k_scales_zeros=k_scales_zeros,
+                v_scales_zeros=v_scales_zeros,
+                quant_policy=quant_policy,
+                flatten_kv_layout='shd',
+            )
+            attn_output, _ = self.flash_attn_varlen_func_v3(
+                q=query,
+                k=flatten_k,
+                v=flatten_v,
+                cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                cu_seqlens_k=attn_metadata.cu_seqlens_k,
+                max_seqlen_q=max_q_seqlen,
+                max_seqlen_k=kv_flatten_size,
+                softmax_scale=self.scale,
                 causal=self.causal,
+                window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
+                softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
             )
         return attn_output
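Note: in the `use_fa3` prefill branch above, the query's last dimension is the value head size plus a rope tail; the diff slices the first `self.v_head_size` channels off as `q_nope` (passed via `qv=`) and the remaining channels as `q_rope` (passed via `q=`). A small shape sketch with made-up sizes (the names `num_tokens` and `rope_dim` are illustrative, not from the commit):

import torch

num_tokens, num_heads, v_head_size, rope_dim = 8, 16, 512, 64
query = torch.randn(num_tokens, num_heads, v_head_size + rope_dim)

q_nope = query[:, :, :v_head_size]  # (8, 16, 512) -> passed as qv=
q_rope = query[:, :, v_head_size:]  # (8, 16, 64)  -> passed as q=
print(q_nope.shape, q_rope.shape)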

@@ -366,13 +531,25 @@ def build(
                                 logical_softcapping=logical_softcapping,
                                 causal=causal,
                                 **kwargs)
-        return TritonAttentionImpl(num_heads,
-                                   head_size,
-                                   scale=scale,
-                                   num_kv_heads=num_kv_heads,
-                                   v_head_size=v_head_size,
-                                   alibi=alibi,
-                                   sliding_window=sliding_window,
-                                   logical_softcapping=logical_softcapping,
-                                   causal=causal,
-                                   **kwargs)
+        elif use_fa3 and not alibi:
+            return FA3Impl(num_heads,
+                           head_size,
+                           scale=scale,
+                           num_kv_heads=num_kv_heads,
+                           v_head_size=v_head_size,
+                           alibi=alibi,
+                           sliding_window=sliding_window,
+                           logical_softcapping=logical_softcapping,
+                           causal=causal,
+                           **kwargs)
+        else:
+            return TritonAttentionImpl(num_heads,
+                                       head_size,
+                                       scale=scale,
+                                       num_kv_heads=num_kv_heads,
+                                       v_head_size=v_head_size,
+                                       alibi=alibi,
+                                       sliding_window=sliding_window,
+                                       logical_softcapping=logical_softcapping,
+                                       causal=causal,
+                                       **kwargs)
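Note: the builder now prefers `FA3Impl` whenever FA3 was importable at module load and alibi is disabled, and otherwise falls back to the Triton implementation. A hypothetical standalone sketch of that precedence (`pick_impl` and `first_branch` are illustrative names, not the repo's API; the first branch stands in for the existing one outside this hunk):

def pick_impl(first_branch: bool, use_fa3: bool, alibi: bool) -> str:
    """Illustrates the dispatch order only; not the repo's API."""
    if first_branch:
        return 'existing branch (unchanged by this commit)'
    elif use_fa3 and not alibi:
        return 'FA3Impl'
    else:
        return 'TritonAttentionImpl'

print(pick_impl(first_branch=False, use_fa3=True, alibi=False))  # FA3Impl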

lmdeploy/pytorch/backends/cuda/op_backend.py

Lines changed: 4 additions & 0 deletions
@@ -130,6 +130,8 @@ def update_step_context(cls, step_context):
         kv_seqlens = step_context.kv_seqlens
         kv_start_loc = None
         kv_flatten_size = None
+        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
+        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
         if not step_context.is_decoding:
             kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
             kv_flatten_size = kv_seqlens.sum().item()
@@ -142,6 +144,8 @@ def update_step_context(cls, step_context):
             kv_seqlens=kv_seqlens,
             kv_flatten_size=kv_flatten_size,
             quant_policy=step_context.kv_quant_policy,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
         )
         if getattr(step_context.model_config, 'use_flash_mla', False) is True:
             if step_context.is_decoding is True:
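Note: `cu_seqlens_q` / `cu_seqlens_k` are the cumulative sequence lengths the FA3 varlen interface expects: a leading zero followed by running totals, so entry i is the start offset of sequence i and the last entry is the total token count. A minimal illustration with hypothetical lengths:

import torch

q_seqlens = torch.tensor([3, 5, 2])
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)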

lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py

Lines changed: 34 additions & 17 deletions
@@ -202,7 +202,8 @@ def flatten_kv_cache(k_caches: Tensor,
                      k_scales_zeros: Tensor = None,
                      v_scales_zeros: Tensor = None,
                      quant_policy: Literal[0, 4, 8] = 0,
-                     kv_layout: str = 'bshd'):
+                     kv_layout: str = 'bshd',
+                     flatten_kv_layout: str = 'hsd'):
     """Recovery paged kv cache to normal kv cache."""
     if kv_layout == 'bshd':
         b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
@@ -230,17 +231,34 @@ def flatten_kv_cache(k_caches: Tensor,
     BLOCK_DK = triton.next_power_of_2(k_head_dim)
     BLOCK_DV = triton.next_power_of_2(v_head_dim)
     BLOCK_BS = k_caches.size(s_dim)
-
-    k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)
-
-    grid = (num_blocks, batch_size, num_heads)
-    if quant_policy == 0:
-        shared_kv = k_caches.data_ptr() == v_caches.data_ptr() and v_head_dim < k_head_dim
-        if shared_kv:
+    shared_kv = k_caches.data_ptr() == v_caches.data_ptr() and v_head_dim < k_head_dim
+    if flatten_kv_layout == 'hsd':
+        k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)
+        if quant_policy == 0 and shared_kv:
             v_states = k_states[..., :v_head_dim]
             v_head_dim = 0
         else:
             v_states = v_caches.new_empty(num_heads, out_size, v_head_dim, dtype=out_dtype)
+        stride_koh = k_states.stride(0)
+        stride_kos = k_states.stride(1)
+        stride_voh = v_states.stride(0)
+        stride_vos = v_states.stride(1)
+    elif flatten_kv_layout == 'shd':
+        k_states = k_caches.new_empty(out_size, num_heads, k_head_dim, dtype=out_dtype)
+        if quant_policy == 0 and shared_kv:
+            v_states = k_states[..., :v_head_dim]
+            v_head_dim = 0
+        else:
+            v_states = v_caches.new_empty(out_size, num_heads, v_head_dim, dtype=out_dtype)
+        stride_koh = k_states.stride(1)
+        stride_kos = k_states.stride(0)
+        stride_voh = v_states.stride(1)
+        stride_vos = v_states.stride(0)
+    else:
+        raise RuntimeError('Unsupported layout.')
+
+    grid = (num_blocks, batch_size, num_heads)
+    if quant_policy == 0:
         _flatten_kv_cache[grid](
             k_caches,
             v_caches,
@@ -257,11 +275,11 @@ def flatten_kv_cache(k_caches: Tensor,
             stride_vcs=v_caches.stride(s_dim),
             stride_vch=v_caches.stride(h_dim),
             stride_vcd=v_caches.stride(d_dim),
-            stride_koh=k_states.stride(0),
-            stride_kos=k_states.stride(1),
+            stride_koh=stride_koh,
+            stride_kos=stride_kos,
             stride_kod=k_states.stride(2),
-            stride_voh=v_states.stride(0),
-            stride_vos=v_states.stride(1),
+            stride_voh=stride_voh,
+            stride_vos=stride_vos,
             stride_vod=v_states.stride(2),
             stride_boff=block_offsets.stride(0),
             OUT_SIZE=out_size,
@@ -272,7 +290,6 @@ def flatten_kv_cache(k_caches: Tensor,
             BLOCK_DV=BLOCK_DV,
         )
     else:
-        v_states = v_caches.new_empty(num_heads, out_size, v_head_dim, dtype=out_dtype)
         _flatten_kv_cache_quant[grid](
             k_caches,
             v_caches,
@@ -299,11 +316,11 @@ def flatten_kv_cache(k_caches: Tensor,
             stride_vszs=v_scales_zeros.stride(s_dim),
             stride_vszh=v_scales_zeros.stride(h_dim),
             stride_vszd=v_scales_zeros.stride(d_dim),
-            stride_koh=k_states.stride(0),
-            stride_kos=k_states.stride(1),
+            stride_koh=stride_koh,
+            stride_kos=stride_kos,
             stride_kod=k_states.stride(2),
-            stride_voh=v_states.stride(0),
-            stride_vos=v_states.stride(1),
+            stride_voh=stride_voh,
+            stride_vos=stride_vos,
             stride_vod=v_states.stride(2),
             stride_boff=block_offsets.stride(0),
             quant_policy=quant_policy,
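Note: the new `flatten_kv_layout` switch only changes how the flattened buffer is laid out. 'hsd' keeps the original (num_heads, out_size, head_dim) shape used by the existing Triton prefill path, while 'shd' produces (out_size, num_heads, head_dim), which is what the FA3 varlen path consumes. A small sketch with made-up sizes:

import torch

num_heads, out_size, head_dim = 4, 32, 128
hsd = torch.empty(num_heads, out_size, head_dim)  # flatten_kv_layout='hsd'
shd = torch.empty(out_size, num_heads, head_dim)  # flatten_kv_layout='shd'

# Same element count, but the head and token strides swap roles,
# which is why the kernel now receives stride_koh/stride_kos explicitly.
print(hsd.stride(), shd.stride())  # (4096, 128, 1) vs (512, 128, 1)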
