
Commit 72fc8aa

[Multi Modal] Add FA3 in VIT (vllm-project#24347)
Signed-off-by: wwl2755 <[email protected]>
1 parent fdb09c7 commit 72fc8aa

13 files changed: +247 additions, -66 deletions

tests/entrypoints/openai/test_vision.py

Lines changed: 2 additions & 2 deletions
@@ -34,11 +34,11 @@
     ],
     [
         "The image shows a Venn diagram with three over",
-        "The image shows a Venn diagram with three intersect",
+        "This image shows a Venn diagram with three over",
     ],
     [
         "This image displays a gradient of colors ranging from",
-        "The image displays a gradient of colors ranging from",
+        "This image displays a gradient of colors forming a spectrum",
     ],
 ]
 

tests/kernels/attention/test_mha_attn.py

Lines changed: 35 additions & 14 deletions
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   CpuPlatform()), \
-                patch("vllm.platforms.current_platform", CpuPlatform()):
+        with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
+                patch("vllm.model_executor.models.vision.current_platform",
+                      CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
+            assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   RocmPlatform()), \
-                patch("vllm.platforms.current_platform", RocmPlatform()), \
-                patch("vllm.attention.layer.current_platform", RocmPlatform()):
+        with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
+                patch("vllm.model_executor.models.vision.current_platform",
+                      RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   CudaPlatform()), \
-                patch("vllm.platforms.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM's FlashAttention
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+                patch("vllm.model_executor.models.vision.current_platform",
+                      CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
+            assert attn.attn_backend == _Backend.FLASH_ATTN
 
-        with patch("vllm.attention.selector.current_platform",
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA not available
+        # - should use xformers
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+                patch("vllm.model_executor.models.vision.current_platform",
                       CudaPlatform()), \
-                patch("vllm.platforms.current_platform", CudaPlatform()):
+                patch("vllm.attention.layer.check_upstream_fa_availability",
+                      return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA available
+        # - should use upstream FA
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+                patch("vllm.model_executor.models.vision.current_platform",
+                      CudaPlatform()), \
+                patch("vllm.attention.layer.check_upstream_fa_availability",
+                      return_value=True), \
+                patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
+                    {
+                        'flash_attn_varlen_func': lambda *args, **kwargs: None
+                    })()}):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
 
 def ref_attention(
     query: torch.Tensor,
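The last CUDA case above stubs flash_attn in sys.modules because, once check_upstream_fa_availability is patched to return True, MultiHeadAttention imports flash_attn_varlen_func from upstream flash_attn inside __init__; the stub lets that import succeed on machines where the package is not installed. A minimal standalone illustration of the same stubbing trick (the fake module below is purely illustrative, not from the PR):

import sys
import types
from unittest.mock import patch

# Fake "flash_attn" module exposing only the symbol the layer imports.
fake_flash_attn = types.ModuleType("flash_attn")
fake_flash_attn.flash_attn_varlen_func = lambda *args, **kwargs: None

with patch.dict(sys.modules, {"flash_attn": fake_flash_attn}):
    # Inside this block, `import flash_attn` resolves to the stub above.
    from flash_attn import flash_attn_varlen_func
    assert flash_attn_varlen_func() is None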

vllm/attention/layer.py

Lines changed: 67 additions & 8 deletions
@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -55,6 +56,14 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
+def check_upstream_fa_availability(dtype: torch.dtype):
+    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+    ) and current_platform.has_device_capability(80):
+        from transformers.utils import is_flash_attn_2_available
+        return is_flash_attn_2_available()
+    return False
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -349,29 +358,55 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+
+        # Determine the attention backend
+        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+
+        # Some auto-selected backends can be upgraded
+        # to upstream flash attention if available.
+        # If vllm native fa is selected, we use it directly.
+        use_upstream_fa = False
+        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+                dtype):
+            backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
+
             self.attn_backend = backend if backend in {
                 _Backend.TORCH_SDPA,
                 _Backend.TORCH_SDPA_VLLM_V1,
                 _Backend.XFORMERS,
                 _Backend.PALLAS_VLLM_V1,
                 _Backend.ROCM_AITER_FA,
-            } else current_platform.get_vit_attn_backend()
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+            } else _Backend.TORCH_SDPA
 
         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA
 
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,
@@ -392,7 +427,31 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = self._flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
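The varlen FlashAttention entry point takes packed tokens rather than a batch dimension, so the fixed-length batches handled by MultiHeadAttention are described by cumulative-length tensors that are plain arithmetic progressions. A small sketch of that bookkeeping with made-up sizes (the values below are illustrative only; the variable names mirror the diff):

import torch

# Illustrative sizes: 2 sequences, 3 query tokens and 5 key tokens each.
bsz, q_len, kv_len, num_heads, head_size = 2, 3, 5, 16, 64

query = torch.randn(bsz, q_len, num_heads, head_size)
key = torch.randn(bsz, kv_len, num_heads, head_size)

# Cumulative sequence lengths mark where each packed sequence starts and ends.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)
print(cu_seqlens_q.tolist())  # [0, 3, 6]
print(cu_seqlens_k.tolist())  # [0, 5, 10]

# flatten(0, 1) packs (batch, seq, heads, head_size) into
# (total_tokens, heads, head_size), the layout the varlen kernel expects.
packed_q = query.flatten(0, 1)  # torch.Size([6, 16, 64])
packed_k = key.flatten(0, 1)    # torch.Size([10, 16, 64])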

vllm/model_executor/models/ernie45_vl.py

Lines changed: 20 additions & 3 deletions
@@ -34,6 +34,7 @@
 from einops import rearrange, repeat
 from transformers import BatchFeature
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -170,7 +171,16 @@ def __init__(
                                   prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -233,7 +243,10 @@ def forward(
         if self.attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
         q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -457,7 +470,11 @@ def __init__(
         ), "vit's config.hidden must be equal to config.embed_dim"
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
 
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
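This encoder and the ones below (glm4_1v, keye, qwen2_5_vl) repeat the same selection recipe: ask get_vit_attn_backend for a backend given the head size and default dtype, then upgrade to upstream FlashAttention when check_upstream_fa_availability says it is importable. A condensed sketch of that recipe as a standalone helper (the helper function itself is illustrative, not part of the diff):

import torch

from vllm.attention.layer import check_upstream_fa_availability
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend


def select_vit_backend(head_size: int) -> tuple[_Backend, bool]:
    """Pick a ViT attention backend; also report whether upstream FA is used."""
    dtype = torch.get_default_dtype()
    backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
    use_upstream_fa = False
    if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(dtype):
        backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
    return backend, use_upstream_fa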

vllm/model_executor/models/glm4_1v.py

Lines changed: 19 additions & 3 deletions
@@ -44,6 +44,7 @@
                                    Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
@@ -260,7 +261,15 @@ def __init__(
         )
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,
@@ -310,7 +319,10 @@ def forward(
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #     flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -715,7 +727,11 @@ def __init__(
         self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)
 
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
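As in ernie45_vl, the kernel itself is resolved at call time: upstream flash_attn when use_upstream_fa was set during init, otherwise vLLM's bundled vllm_flash_attn. The same choice written as a tiny hypothetical helper (not part of the PR):

def resolve_flash_attn_varlen_func(use_upstream_fa: bool):
    """Return the varlen FlashAttention kernel from the chosen distribution."""
    if use_upstream_fa:
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    return flash_attn_varlen_func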

vllm/model_executor/models/keye.py

Lines changed: 15 additions & 2 deletions
@@ -17,6 +17,7 @@
                          BaseModelOutputWithPooling)
 from transformers.utils import torch_int
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -374,7 +375,16 @@ def __init__(
         )
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")
@@ -428,7 +438,10 @@ def forward(
         )
 
         if self.attn_backend == _Backend.FLASH_ATTN:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 21 additions & 3 deletions
@@ -38,6 +38,7 @@
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -298,7 +299,16 @@ def __init__(
                                         disable_tp=use_data_parallel)
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -359,7 +369,10 @@ def forward(
         if self.attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
         q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -628,7 +641,12 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
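A quick way to see what the new selection logic resolves to on a given machine is to call the two functions from the diff directly, comparing an FA-friendly head size (64) with one that is not divisible by 32 (72). This is a standalone check rather than anything the PR itself adds:

import torch

from vllm.attention.layer import check_upstream_fa_availability
from vllm.model_executor.models.vision import get_vit_attn_backend

torch.set_default_dtype(torch.float16)
dtype = torch.get_default_dtype()

# On non-CUDA or pre-Ampere hardware this simply reports False.
print("upstream FA importable:", check_upstream_fa_availability(dtype))
for head_size in (64, 72):
    backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
    print(f"head_size={head_size} -> {backend}")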
