
Commit d72f284: support fa3
1 parent 7e2c8b8 commit d72f284

File tree

13 files changed (+156, -17 lines)

lmdeploy/pytorch/backends/attention.py

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,7 @@ def __init__(
         logit_softcapping: float = None,
         causal: bool = True,
         use_flash_mla: bool = False,
+        use_flash_attn3: bool = False,
         **kwargs,
     ) -> None:
         if scale is None:
@@ -57,6 +58,7 @@ def __init__(
         self.logit_softcapping = logit_softcapping
         self.causal = causal
         self.use_flash_mla = use_flash_mla
+        self.use_flash_attn3 = use_flash_attn3

     @abstractmethod
     def forward(
@@ -92,6 +94,7 @@ def build(
         logical_softcapping: float = None,
         causal: bool = True,
         use_flash_mla: bool = False,
+        use_flash_attn3: bool = False,
         learnable_sink: bool = False,
         **kwargs,
     ) -> AttentionImpl[T]:
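
The change above threads a new use_flash_attn3 switch through both the implementation constructor and the layer-level build factory, mirroring the existing use_flash_mla flag. A minimal sketch of that dispatch pattern with hypothetical class names (not the lmdeploy API), assuming the flag simply selects which backend class the factory returns:

# Hedged sketch: hypothetical names, illustrating how a boolean capability
# flag is carried from a factory into the chosen attention implementation.
from dataclasses import dataclass


@dataclass
class AttnArgs:
    num_heads: int
    head_size: int
    use_flash_mla: bool = False
    use_flash_attn3: bool = False


class TritonAttention:
    def __init__(self, args: AttnArgs):
        self.args = args


class FlashAttn3Attention:
    def __init__(self, args: AttnArgs):
        self.args = args


def build_attention(args: AttnArgs):
    """Pick FA3 when requested, otherwise fall back to the default backend."""
    if args.use_flash_attn3:
        return FlashAttn3Attention(args)
    return TritonAttention(args)


impl = build_attention(AttnArgs(num_heads=32, head_size=128, use_flash_attn3=True))
print(type(impl).__name__)  # FlashAttn3Attention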

lmdeploy/pytorch/backends/cuda/attention.py

Lines changed: 25 additions & 6 deletions
@@ -407,8 +407,9 @@ def __init__(
             causal=causal,
             **kwargs,
         )
-        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
+        from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
         self.flash_attn_varlen_func_v3 = flash_attn_varlen_func
+        self.flash_attn_with_kvcache_v3 = flash_attn_with_kvcache

     def forward(
         self,
@@ -460,11 +461,10 @@ def forward(
             quant_policy=quant_policy,
         )

-        q_shape = query.shape
-        o_shape = q_shape[:-1] + (self.v_head_size, )
-        attn_output = query.new_empty(o_shape)
-
         if is_decoding:
+            q_shape = query.shape
+            o_shape = q_shape[:-1] + (self.v_head_size, )
+            attn_output = query.new_empty(o_shape)
             self.paged_attention_fwd(
                 query,
                 k_cache,
@@ -480,6 +480,24 @@ def forward(
                 logit_softcapping=self.logit_softcapping,
             )
         else:
+            sliding_window = (-1, -1) if self.sliding_window is None else self.sliding_window
+            if isinstance(sliding_window, int):
+                sliding_window = (sliding_window, sliding_window)
+            attn_output = self.flash_attn_with_kvcache_v3(
+                query,
+                k_cache,
+                v_cache,
+                cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
+                cu_seqlens_q=attn_metadata.cu_seqlens_q,
+                cu_seqlens_k_new=attn_metadata.cu_seqlens_k,
+                max_seqlen_q=max_q_seqlen,
+                page_table=block_offsets,
+                softmax_scale=self.scale,
+                causal=self.causal,
+                window_size=sliding_window,
+                softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
+            )
+            return attn_output
             flatten_k, flatten_v = self.flatten_kv_cache(
                 k_cache,
                 v_cache,
@@ -527,6 +545,7 @@ def build(
         logical_softcapping: float = None,
         causal: bool = True,
         use_flash_mla: bool = False,
+        use_flash_attn3: bool = False,
        learnable_sink: bool = False,
        **kwargs,
    ) -> TritonAttentionImpl:
@@ -542,7 +561,7 @@ def build(
                           logical_softcapping=logical_softcapping,
                           causal=causal,
                           **kwargs)
-        elif use_fa3 and not alibi and not learnable_sink:
+        elif use_flash_attn3 and not alibi and not learnable_sink:
            return FA3Impl(num_heads,
                           head_size,
                           scale=scale,
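
Two details of the new prefill path are easy to miss: the sliding window is normalized into the (left, right) tuple form passed to flash_attn_with_kvcache, with (-1, -1) meaning no window, and the softcap argument falls back to -1.0 when logit softcapping is unset. A small, self-contained sketch of that normalization (plain Python, not tied to the lmdeploy classes):

from typing import Optional, Tuple, Union

WindowSpec = Union[int, Tuple[int, int], None]


def normalize_window(sliding_window: WindowSpec) -> Tuple[int, int]:
    """Coerce a sliding-window setting into the (left, right) tuple form."""
    if sliding_window is None:
        return (-1, -1)          # no sliding window
    if isinstance(sliding_window, int):
        return (sliding_window, sliding_window)
    return tuple(sliding_window)


def normalize_softcap(logit_softcapping: Optional[float]) -> float:
    """Map 'no softcapping' (None) onto the sentinel used in the call above."""
    return -1.0 if logit_softcapping is None else logit_softcapping


assert normalize_window(None) == (-1, -1)
assert normalize_window(4096) == (4096, 4096)
assert normalize_window((1024, 0)) == (1024, 0)
assert normalize_softcap(None) == -1.0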

lmdeploy/pytorch/backends/cuda/op_backend.py

Lines changed: 11 additions & 4 deletions
@@ -130,9 +130,16 @@ def update_step_context(cls, step_context):
         kv_seqlens = step_context.kv_seqlens
         kv_start_loc = None
         kv_flatten_size = None
-        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
-        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
-        if not step_context.is_decoding:
+        use_flash_mla = getattr(step_context.model_config, 'use_flash_mla', False)
+        use_flash_attn3 = getattr(step_context.model_config, 'use_flash_attn3', False)
+        cu_seqlens_q = None
+        cu_seqlens_k = None
+        if use_flash_mla or use_flash_attn3:
+            cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
+            cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
+            step_context.block_offsets = step_context.block_offsets.to(torch.int32)
+
+        if (not step_context.is_decoding) and not use_flash_attn3:
             kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
             kv_flatten_size = step_context.sum_kv_seqlen
         attn_metadata = attn_meta_cls(
@@ -147,7 +154,7 @@ def update_step_context(cls, step_context):
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
         )
-        if getattr(step_context.model_config, 'use_flash_mla', False) is True:
+        if use_flash_mla:
             if step_context.is_decoding is True:
                 cls.update_meta_flashmla(attn_metadata, step_context.model_config.num_attention_heads)

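
Both the FlashMLA and FA3 paths consume cumulative sequence lengths rather than flattened start locations, which is why cu_seqlens_q and cu_seqlens_k are only built when one of the flags is set. The pad-then-cumsum idiom produces an int32 exclusive prefix sum with a leading zero; a small CPU-only illustration:

import torch

q_seqlens = torch.tensor([3, 5, 2])
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
# Token j of sequence b sits at flattened offset cu_seqlens_q[b] + j, and
# cu_seqlens_q[-1] is the total number of query tokens in the batch.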

lmdeploy/pytorch/config.py

Lines changed: 1 addition & 0 deletions
@@ -201,6 +201,7 @@ class ModelConfig:
     cogvlm_style: bool = False
     custom_module_map: Dict[str, setattr] = None
     use_flash_mla: bool = False
+    use_flash_attn3: bool = False

     def get_head_size(self):
         """Get head size."""

lmdeploy/pytorch/configurations/llama.py

Lines changed: 3 additions & 1 deletion
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .builder import AutoModelConfigBuilder
 from .default import DefaultModelConfigBuilder
+from .utils import flash_attn_v3_available


 class LlamaModelConfigBuilder(AutoModelConfigBuilder):
@@ -26,5 +27,6 @@ def build(cls, hf_config, model_path: str = None, is_draft_model: bool = False,
             num_layers = cfg.num_layers
             hf_config.aux_hidden_state_layers = (2, num_layers // 2, num_layers - 3)
         cfg.hf_config = hf_config
-
+        cfg.use_flash_attn3 = flash_attn_v3_available()
+        cfg.hf_config.use_flash_attn3 = cfg.use_flash_attn3
         return cfg

lmdeploy/pytorch/configurations/utils.py

Lines changed: 15 additions & 0 deletions
@@ -19,3 +19,18 @@ def flash_mla_available():
     except ImportError:
         logger.warning('For higher performance, please install flash_mla https://github.com/deepseek-ai/FlashMLA')
     return use_flash_mla
+
+
+def flash_attn_v3_available():
+    """Check if flash attn v3 is available."""
+    use_fa3 = False
+    try:
+        # currently flash-attention only supports FA3 on sm90a with CUDA >= 12.3
+        if (torch.cuda.get_device_capability()[0] == 9) and (torch.version.cuda >= '12.3'):
+            import flash_attn_interface  # noqa: F401
+            assert torch.ops.flash_attn_3 is not None
+            use_fa3 = True
+    except Exception:
+        logger.warning('For higher performance, please install FlashAttention-3 '
+                       'https://github.com/Dao-AILab/flash-attention')
+    return use_fa3
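
The probe requires an sm90 (Hopper-class) GPU, CUDA 12.3 or newer, and an importable flash_attn_interface module, and it degrades to a warning otherwise. A standalone sketch for checking a machine by hand, mirroring (as an assumption) the same conditions:

import torch

if not torch.cuda.is_available():
    print('No CUDA device visible; FA3 cannot be used.')
else:
    major, minor = torch.cuda.get_device_capability()
    print(f'compute capability {major}.{minor}, CUDA {torch.version.cuda}')
    # The string comparison mirrors the committed check; it is fine for 12.x versions.
    if major == 9 and torch.version.cuda >= '12.3':
        try:
            import flash_attn_interface  # noqa: F401
            print('flash_attn_interface found: the FA3 path can be selected.')
        except ImportError:
            print('Install FlashAttention-3: https://github.com/Dao-AILab/flash-attention')
    else:
        print('FA3 needs an sm90 GPU and CUDA >= 12.3.')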

lmdeploy/pytorch/engine/engine.py

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,7 @@

 import numpy as np
 import torch
+from torch.profiler import record_function

 from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig
 from lmdeploy.pytorch.disagg.config import EngineRole
@@ -747,6 +748,7 @@ def __has_values(input_multimodals):

     @torch.inference_mode()
     @logging_timer('create_spec_inputs', logger)
+    @record_function('create_spec_inputs')
     def _create_spec_inputs(self, messages: SeqList, token_ids: List[List[int]]):
         """Create spec inputs from messages."""

@@ -782,6 +784,7 @@ def _create_spec_inputs(self, messages: SeqList, token_ids: List[List[int]]):

     @torch.inference_mode()
     @logging_timer('CreateModelInputs', logger)
+    @record_function('CreateModelInputs')
     def create_model_inputs(self, messages: SeqList, is_prefill: bool):
         """Create model inputs from messages.

@@ -933,6 +936,7 @@ def _make_spec_stats(self, seqs: SeqList, next_token_ids: torch.LongTensor):

         return all_stats

+    @record_function('make_infer_outputs')
     def _make_infer_outputs(
         self,
         batched_outputs: BatchedOutputs,
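
The record_function decorators are unrelated to FA3 itself; they label the input-building and output-packing stages so they show up as named ranges in torch.profiler traces. A minimal sketch of how such a label surfaces (hypothetical function, standard PyTorch profiler APIs):

import torch
from torch.profiler import ProfilerActivity, profile, record_function


@record_function('CreateModelInputs')   # same decorator style as in the diff above
def create_model_inputs(n: int) -> torch.Tensor:
    return torch.randn(n, n) @ torch.randn(n, n)


with profile(activities=[ProfilerActivity.CPU]) as prof:
    create_model_inputs(256)

# The 'CreateModelInputs' range appears in the summary table and in a Chrome
# trace exported with prof.export_chrome_trace('trace.json').
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=5))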

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 7 additions & 5 deletions
@@ -1262,11 +1262,13 @@ def build_cache_engine(self):

     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
         cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
-        output = model_forward(self.patched_model,
-                               inputs,
-                               self.cache_engine,
-                               stream=self.stream,
-                               output_position_ids=False)
+        output = model_forward(
+            self.patched_model,
+            inputs,
+            self.cache_engine,
+            stream=self.stream,
+            output_position_ids=self.spec_agent is not None,
+        )
         return output

     async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):

lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py

Lines changed: 76 additions & 0 deletions
@@ -333,3 +333,79 @@ def flatten_kv_cache(k_caches: Tensor,
     )

     return k_states, v_states
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['max_seq_len'],
+        x_vals=[128 * i for i in range(1, 33)],
+        line_arg='provider',
+        line_vals=['hsd', 'shd'],
+        line_names=['hsd', 'shd'],
+        styles=[('blue', '-'), ('red', '-')],
+        ylabel='time/ms',
+        plot_name='bench-flatten-kvcache-performance',
+        args={},
+    ))
+def bench_flatten_kv_cache(max_seq_len: int,
+                           batch_size: int = 128,
+                           num_blocks: int = 6400,
+                           block_size: int = 64,
+                           dtype: torch.dtype = torch.float16,
+                           provider='hsd'):
+    """Benchmark."""
+    head_dim = 128
+    num_head = 8
+    seqlens = torch.tensor([max_seq_len] * batch_size, dtype=torch.long, device='cuda')
+    block_offsets = torch.arange(batch_size * ((max_seq_len + block_size) // block_size),
+                                 dtype=torch.int32,
+                                 device='cuda').reshape(batch_size, -1)
+    out_size = batch_size * ((max_seq_len + block_size) // block_size) * block_size
+    start_loc = seqlens.cumsum(0) - seqlens
+
+    k_caches = torch.randn((num_blocks, block_size, num_head, head_dim),
+                           dtype=dtype,
+                           device='cuda',
+                           requires_grad=False)
+    v_caches = torch.randn((num_blocks, block_size, num_head, head_dim),
+                           dtype=dtype,
+                           device='cuda',
+                           requires_grad=False)
+
+    def flatten_hsd():
+        return flatten_kv_cache(k_caches,
+                                v_caches,
+                                seqlens,
+                                block_offsets,
+                                start_loc,
+                                out_size,
+                                flatten_kv_layout='hsd',
+                                kv_layout='bshd',
+                                out_dtype=dtype)
+
+    def flatten_shd():
+        return flatten_kv_cache(k_caches,
+                                v_caches,
+                                seqlens,
+                                block_offsets,
+                                start_loc,
+                                out_size,
+                                flatten_kv_layout='shd',
+                                kv_layout='bshd',
+                                out_dtype=dtype)
+
+    if provider == 'hsd':
+        flatten_op = flatten_hsd
+    else:
+        flatten_op = flatten_shd
+    quantiles = [0.5, 0.2, 0.8]
+    ms, min_ms, max_ms = triton.testing.do_bench(flatten_op, quantiles=quantiles, rep=500)
+
+    def perf(ms):
+        return ms
+
+    return perf(ms), perf(max_ms), perf(min_ms)
+
+
+if __name__ == '__main__':
+    bench_flatten_kv_cache.run(print_data=True, show_plots=True, save_path='perf_flatten_kv_cache')
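
The benchmark compares the 'hsd' and 'shd' flattening layouts across sequence lengths; running the file directly (via the __main__ block above) prints a table and writes a plot under perf_flatten_kv_cache. The harness itself is generic: with quantiles=[0.5, 0.2, 0.8], triton.testing.do_bench returns the median, 20th and 80th percentile times in milliseconds. A tiny standalone sketch of that pattern on an unrelated op:

import torch
import triton

x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)

# do_bench repeatedly times the callable; the three returned values follow
# the order of the requested quantiles (median, p20, p80 here).
ms, p20, p80 = triton.testing.do_bench(lambda: x @ x, quantiles=[0.5, 0.2, 0.8], rep=500)
print(f'fp16 matmul: median {ms:.3f} ms (p20 {p20:.3f}, p80 {p80:.3f})')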

lmdeploy/pytorch/models/llama.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ def __init__(self, config: LlamaConfig, dtype: torch.dtype = None, device: torch
             head_dim,
             num_kv_heads=num_key_value_heads,
             v_head_size=head_dim,
+            use_flash_attn3=getattr(config, 'use_flash_attn3', False),
         )

         # o_proj
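
getattr with a False default keeps older configs working: an hf_config produced before this commit has no use_flash_attn3 attribute, so the attention layer silently falls back instead of raising AttributeError. A two-line illustration with stand-in objects:

from types import SimpleNamespace

old_cfg = SimpleNamespace()                       # config without the new attribute
new_cfg = SimpleNamespace(use_flash_attn3=True)   # config built by the updated builder

print(getattr(old_cfg, 'use_flash_attn3', False))  # False
print(getattr(new_cfg, 'use_flash_attn3', False))  # True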
