 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op

@@ -55,6 +56,14 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS


+def check_upstream_fa_availability(dtype: torch.dtype):
+    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+    ) and current_platform.has_device_capability(80):
+        from transformers.utils import is_flash_attn_2_available
+        return is_flash_attn_2_available()
+    return False
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.

@@ -349,29 +358,55 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+
+        # Determine the attention backend
+        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+
+        # Some auto-selected backends can be upgraded
+        # to upstream flash attention if available.
+        # If vllm native fa is selected, we use it directly.
+        use_upstream_fa = False
+        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+                dtype):
+            backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
+
             self.attn_backend = backend if backend in {
                 _Backend.TORCH_SDPA,
                 _Backend.TORCH_SDPA_VLLM_V1,
                 _Backend.XFORMERS,
                 _Backend.PALLAS_VLLM_V1,
                 _Backend.ROCM_AITER_FA,
-            } else current_platform.get_vit_attn_backend()
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+            } else _Backend.TORCH_SDPA

         if (self.attn_backend == _Backend.XFORMERS
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA

+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,
@@ -392,7 +427,31 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)

-        if self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = self._flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops

             out = xops.memory_efficient_attention_forward(query,
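
A side note on the varlen call in the last hunk (my reading, not part of the diff): flash_attn_varlen_func consumes token-packed (total_tokens, num_heads, head_dim) tensors plus cumulative sequence lengths, and since every sequence in this batch has the same length, the cu_seqlens tensors are simply multiples of q_len / kv_len. A minimal plain-PyTorch sketch of that mapping, with illustrative shapes:

import torch

bsz, q_len, kv_len, num_heads, head_dim = 2, 3, 5, 8, 64

# Sequence i occupies [cu_seqlens[i], cu_seqlens[i + 1]) in the packed token
# dimension; for equal-length sequences these boundaries are just multiples.
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32)
assert cu_seqlens_q.tolist() == [0, 3, 6]
assert cu_seqlens_k.tolist() == [0, 5, 10]

# query.flatten(0, 1) packs (bsz, q_len, num_heads, head_dim) into the
# (bsz * q_len, num_heads, head_dim) layout the varlen kernel expects.
query = torch.randn(bsz, q_len, num_heads, head_dim)
assert query.flatten(0, 1).shape == (bsz * q_len, num_heads, head_dim)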
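For reference, the gating in check_upstream_fa_availability can be approximated without vLLM's platform helpers. This is a hedged sketch using plain torch and transformers calls; the helper name has_upstream_fa is mine, not from the PR, and it assumes current_platform.is_cuda() / has_device_capability(80) correspond roughly to torch.cuda.is_available() and compute capability >= 8.0:

import torch

def has_upstream_fa(dtype: torch.dtype) -> bool:
    # Upstream flash-attention is only considered for fp16/bf16 on CUDA
    # devices of compute capability 8.0+ (Ampere or newer).
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        return False
    # transformers reports whether the flash_attn package is installed
    # and importable.
    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()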