
Commit 5b8ade8

mxuax and hsliuustc0106 authored
[Diffusion][Bugfix] Fix the flash_attn backends selection logic (vllm-project#983)
Signed-off-by: mxuax <mxuax@connect.ust.hk>
Signed-off-by: XU Mingshi <91017482+mxuax@users.noreply.github.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent def3956 commit 5b8ade8

File tree

8 files changed (+183, -126 lines)


vllm_omni/diffusion/attention/backends/flash_attn.py

Lines changed: 22 additions & 9 deletions
@@ -10,8 +10,26 @@
     AttentionMetadata,
 )
 
+# Import flash attention functions with fallback chain from utils/fa.py
+# FA3 (fa3_fwd_interface) -> FA3 (flash_attn_interface) -> FA2 (flash_attn)
+from vllm_omni.diffusion.attention.backends.utils.fa import (
+    HAS_FLASH_ATTN,
+    _pad_input,
+    _unpad_input,
+    _upad_input,
+    flash_attn_func,
+    flash_attn_varlen_func,
+)
+
 logger = init_logger(__name__)
 
+if not HAS_FLASH_ATTN:
+    raise ImportError(
+        "FlashAttentionBackend requires Flash Attention. "
+        "Please install one of: fa3-fwd, flash-attention, or flash-attn. "
+        "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA"
+    )
+
 
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -56,14 +74,6 @@ def forward_cuda(
         attn_metadata: AttentionMetadata = None,
     ) -> torch.Tensor:
         """CUDA/ROCm flash attention implementation."""
-        from vllm_omni.diffusion.attention.backends.utils.fa import (
-            _pad_input,
-            _unpad_input,
-            _upad_input,
-            flash_attn_func,
-            flash_attn_varlen_func,
-        )
-
         query_length = query.size(1)
         attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None
         # Contains at least one padding token in the sequence
@@ -92,13 +102,16 @@ def forward_cuda(
             out = _pad_input(out_unpad, indices_q, query.size(0), query_length)
 
         else:
-            out: torch.Tensor = flash_attn_func(
+            out = flash_attn_func(
                 query,
                 key,
                 value,
                 causal=self.causal,
                 softmax_scale=self.softmax_scale,
             )
+            # FA3 may return (out, lse) tuple, FA2 returns just out
+            if isinstance(out, tuple):
+                out = out[0]
         return out
 
     def forward_npu(
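The `isinstance(out, tuple)` check above is what lets a single code path serve both kernels: FA3 builds may return `(out, softmax_lse)` while FA2 returns the output tensor alone. A minimal standalone sketch of that normalization pattern (the helper name is made up for illustration and is not part of this patch):

```python
def normalize_flash_output(result):
    """Return only the attention output, whether the kernel gave (out, lse) or out.

    FA3-style interfaces may return an (out, softmax_lse) tuple; FA2 returns a bare tensor.
    """
    if isinstance(result, tuple):
        return result[0]
    return result


# Usage with whichever flash_attn_func the fallback chain resolved to:
# out = normalize_flash_output(flash_attn_func(query, key, value, causal=True))
```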

vllm_omni/diffusion/attention/backends/ring/ring_globals.py

Lines changed: 38 additions & 15 deletions
@@ -3,57 +3,80 @@
 # Copyright (c) 2024, Jiarui Fang.
 # Adapted from https://github.com/feifeibear/long-context-attention
 
-
-# test if flash_attn is available
+# test if flash_attn (FA2) is available
 try:
     import flash_attn  # noqa: F401
-    from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward  # noqa: F401
+    from flash_attn.flash_attn_interface import _flash_attn_forward  # noqa: F401
 
     HAS_FLASH_ATTN = True
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     HAS_FLASH_ATTN = False
 
+# FA3 detection: try multiple sources (forward only, no backward needed for inference)
+# Source 1: flash_attn_interface (from flash-attention source build)
+# Source 2: fa3_fwd_interface (from fa3-fwd PyPI package, supports Ampere/Ada/Hopper)
+# Note: FA3 high-level API may or may not return softmax_lse depending on version.
+# For Ring Attention which requires LSE, we fall back to low-level API if needed.
+HAS_FA3 = False
+fa3_fwd_func = None  # Low-level forward function (_flash_attn_forward)
+fa3_attn_func = None  # High-level attention function (flash_attn_func)
+
+# Try flash_attn_interface first (from flash-attention source build)
 try:
-    from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward  # noqa: F401
-    from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper  # noqa: F401
-    from flash_attn_interface import flash_attn_func as flash3_attn_func  # noqa: F401
+    from flash_attn_interface import _flash_attn_forward as fa3_fwd_func  # noqa: F401
+    from flash_attn_interface import flash_attn_func as fa3_attn_func  # noqa: F401
+
+    HAS_FA3 = True
+except (ImportError, ModuleNotFoundError):
+    pass
+
+# Fallback: try fa3_fwd_interface (PyPI package, supports Ampere/Ada/Hopper)
+if not HAS_FA3:
+    try:
+        from fa3_fwd_interface import _flash_attn_forward as fa3_fwd_func  # noqa: F401
+        from fa3_fwd_interface import flash_attn_func as fa3_attn_func  # noqa: F401
+
+        HAS_FA3 = True
+    except (ImportError, ModuleNotFoundError):
+        pass
 
-    HAS_FLASH_ATTN_HOPPER = True
-except ImportError:
-    HAS_FLASH_ATTN_HOPPER = False
+# Legacy aliases for backward compatibility
+HAS_FLASH_ATTN_HOPPER = HAS_FA3
+flash_attn_forward_hopper = fa3_fwd_func
+flash3_attn_func = fa3_attn_func
 
 try:
     from flashinfer.prefill import single_prefill_with_kv_cache  # noqa: F401
 
     HAS_FLASHINFER = True
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     HAS_FLASHINFER = False
 
 try:
     import aiter  # noqa: F401
     from aiter import flash_attn_func as flash_attn_func_aiter  # noqa: F401
 
     HAS_AITER = True
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     HAS_AITER = False
 
 try:
     import sageattention  # noqa: F401
 
     HAS_SAGE_ATTENTION = True
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     HAS_SAGE_ATTENTION = False
 
 try:
     import spas_sage_attn  # noqa: F401
 
     HAS_SPARSE_SAGE_ATTENTION = True
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     HAS_SPARSE_SAGE_ATTENTION = False
 
 try:
     import torch_npu  # noqa: F401
 
     HAS_NPU = True
-except ImportError:
+except (ImportError, ModuleNotFoundError):
     HAS_NPU = False

vllm_omni/diffusion/attention/backends/ring/ring_kernels.py

Lines changed: 29 additions & 69 deletions
@@ -6,7 +6,13 @@
 
 import torch
 
-from .ring_globals import HAS_AITER, HAS_FLASH_ATTN, HAS_FLASH_ATTN_HOPPER, HAS_FLASHINFER, HAS_NPU
+from .ring_globals import (
+    HAS_AITER,
+    HAS_FA3,
+    HAS_FLASH_ATTN,
+    HAS_FLASHINFER,
+    fa3_fwd_func,
+)
 
 _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention
 _scaled_dot_product_efficient_attention = torch.ops.aten._scaled_dot_product_efficient_attention
@@ -26,23 +32,11 @@
     import flash_attn
     from flash_attn.flash_attn_interface import _flash_attn_forward
 
-if HAS_FLASH_ATTN_HOPPER:
-    from flash_attn_interface import _flash_attn_backward as flash_attn_func_hopper_backward
-    from flash_attn_interface import _flash_attn_forward as flash_attn_forward_hopper
-    from flash_attn_interface import flash_attn_func as flash3_attn_func
-else:
-    flash_attn_forward_hopper = None
-    flash_attn_func_hopper_backward = None
-    flash3_attn_func = None
-
 if HAS_FLASHINFER:
     from flashinfer.prefill import single_prefill_with_kv_cache
 
 _LOG2_E = math.log2(math.e)
 
-if HAS_NPU:
-    import torch_npu
-
 
 def pytorch_attn_forward(
     q: torch.Tensor,
@@ -146,52 +140,35 @@ def flash_attn_forward(
     return block_out, block_lse
 
 
-def flash_attn3_func_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
-):
-    assert HAS_FLASH_ATTN_HOPPER
-    # current signature of flash_attn_forward_hopper:
-    # (q, k, v, softmax_scale, causal, window_size, descale_q=None, descale_k=None, descale_v=None, gqa_parallel=False)
+def fa3_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax):
+    """FA3 forward pass for inference.
+
+    FA3 supports Ampere, Ada, and Hopper GPUs. Dropout is ignored since FA3 is inference-only.
+    Uses low-level API (_flash_attn_forward) which always returns softmax_lse,
+    required for Ring Attention's correct accumulation.
+    """
+    assert HAS_FA3, "FA3 is not available"
+    assert fa3_fwd_func is not None, "FA3 low-level API (fa3_fwd_func) not available"
 
-    out, softmax_lse, *unused = flash_attn_forward_hopper(
-        q=q,
-        k=k,
-        v=v,
-        k_new=None,
-        v_new=None,
-        qv=None,
-        out=None,
-        cu_seqlens_q=None,
-        cu_seqlens_k=None,
-        cu_seqlens_k_new=None,
-        seqused_q=None,
-        seqused_k=None,
-        max_seqlen_q=None,
-        max_seqlen_k=None,
-        page_table=None,
-        kv_batch_idx=None,
-        leftpad_k=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        seqlens_rotary=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
+    # Low-level API always returns (out, softmax_lse, S_dmask, rng_state)
+    out, softmax_lse, *_ = fa3_fwd_func(
+        q,
+        k,
+        v,
         softmax_scale=softmax_scale,
-        causal=False,
-        window_size=(-1, -1),
-        attention_chunk=0,
-        softcap=0.0,
-        rotary_interleaved=True,
-        scheduler_metadata=None,
-        num_splits=0,
-        pack_gqa=None,
-        sm_margin=0,
+        causal=causal,
+        window_size_left=window_size[0] if window_size else -1,
+        window_size_right=window_size[1] if window_size else -1,
+        softcap=softcap if softcap else 0.0,
    )
 
     return out, softmax_lse
 
 
+# Legacy alias for backward compatibility
+flash_attn3_func_forward = fa3_forward
+
+
 def flash_attn_forward_aiter(
     q,
     k,
@@ -264,20 +241,3 @@ def flashinfer_attn_forward(
         raise ValueError(f"Invalid input shape: {q.shape}")
     lse = lse / _LOG2_E
     return out, lse
-
-
-def npu_attn_forward(q, k, v, softmax_scale=None, layout="BSND"):
-    assert HAS_NPU, "torch_npu is not available"
-    softmax_scale = q.shape[-1] ** (-0.5)
-    block_out, block_lse = torch_npu.npu_fused_infer_attention_score(
-        q,
-        k,
-        v,
-        num_heads=q.shape[-2],
-        input_layout=layout,
-        scale=softmax_scale,
-        softmax_lse_flag=True,
-        pre_tokens=65535,
-        next_tokens=65535,
-    )
-    return block_out, block_lse.squeeze(dim=-1)
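fa3_forward returns softmax_lse because Ring Attention must merge partial results computed against different key/value blocks, and that merge is only correct when each block's log-sum-exp is available. A minimal sketch of the standard merge (not this repo's implementation; it assumes flash_attn's usual layouts, out as (batch, seq, heads, dim) and lse as (batch, heads, seq)):

```python
import torch


def merge_attn_blocks(out1, lse1, out2, lse2):
    """Combine two partial attention outputs computed over disjoint key/value blocks."""
    # Bring LSE from (batch, heads, seq) to (batch, seq, heads, 1) so it broadcasts.
    lse1 = lse1.transpose(1, 2).unsqueeze(-1)
    lse2 = lse2.transpose(1, 2).unsqueeze(-1)
    lse = torch.logaddexp(lse1, lse2)  # combined normalizer in log space
    out = out1 * (lse1 - lse).exp() + out2 * (lse2 - lse).exp()
    return out, lse.squeeze(-1).transpose(1, 2)  # restore (batch, heads, seq)
```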

vllm_omni/diffusion/attention/backends/ring/ring_selector.py

Lines changed: 0 additions & 10 deletions
@@ -10,7 +10,6 @@
 import torch
 
 from .ring_globals import (
-    HAS_NPU,
     HAS_SAGE_ATTENTION,
     HAS_SPARSE_SAGE_ATTENTION,
 )
@@ -28,9 +27,6 @@
 if HAS_SPARSE_SAGE_ATTENTION:
     from spas_sage_attn.autotune import SparseAttentionMeansim
 
-if HAS_NPU:
-    from torch_npu import npu_fused_infer_attention_score
-
 
 class AttnType(Enum):
     AITER = "aiter"
@@ -44,7 +40,6 @@ class AttnType(Enum):
     SAGE_FP8 = "sage_fp8"
     SAGE_FP8_SM90 = "sage_fp8_sm90"
     SPARSE_SAGE = "sparse_sage"
-    NPU = "npu"
 
     @classmethod
     def from_string(cls, s: str):
@@ -157,11 +152,6 @@ def fn(q, k, v, causal=False, softmax_scale=None, *args, **kwargs):
 
         return fn
 
-    elif impl_type == AttnType.NPU:
-        if not HAS_NPU:
-            raise ImportError("torch_npu is not available!")
-        return npu_fused_infer_attention_score
-
     elif attn_processor is not None:
         return attn_processor
 
vllm_omni/diffusion/attention/backends/utils/fa.py

Lines changed: 44 additions & 8 deletions
@@ -17,17 +17,53 @@
 
 from vllm_omni.platforms import current_omni_platform
 
+# Flash Attention function detection with fallback chain
+flash_attn_func = None
+flash_attn_varlen_func = None
+
 if current_omni_platform.is_rocm():
-    from vllm._aiter_ops import is_aiter_found_and_supported
+    # ROCm: try Aiter first
+    try:
+        from vllm._aiter_ops import is_aiter_found_and_supported
 
-    # Choose to enable this by default on ROCm
-    # Whenever possible as it is the fastest backend
-    if is_aiter_found_and_supported():
-        from aiter import flash_attn_func, flash_attn_varlen_func  # noqa: F401
-    else:
-        raise ImportError("Aiter is not found and supported on currentROCm device.")
+        if is_aiter_found_and_supported():
+            from aiter import flash_attn_func, flash_attn_varlen_func  # noqa: F401
+    except (ImportError, ModuleNotFoundError):
+        pass
 else:
-    from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
+    # CUDA: try FA3 -> FA2 fallback chain
+    # Try FA3 from fa3-fwd PyPI package
+    try:
+        from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
+    except (ImportError, ModuleNotFoundError):
+        pass
+
+    # Fallback: Try FA3 from flash-attention source build
+    if flash_attn_func is None:
+        try:
+            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func  # noqa: F401
+        except (ImportError, ModuleNotFoundError):
+            pass
+
+    # Fallback: Try FA2 from flash-attn package (try multiple import paths)
+    if flash_attn_func is None:
+        try:
+            from flash_attn import flash_attn_func, flash_attn_varlen_func  # noqa: F401
+        except (ImportError, ModuleNotFoundError):
+            pass
+
+    if flash_attn_func is None:
+        try:
+            from flash_attn.flash_attn_interface import (  # noqa: F401
+                flash_attn_func,
+                flash_attn_varlen_func,
+            )
+        except (ImportError, ModuleNotFoundError):
+            pass
+
+# If no FA backend available, SDPA backend will be selected at the platform level
+# flash_attn_func and flash_attn_varlen_func will be None
+HAS_FLASH_ATTN = flash_attn_func is not None
 
 
 def _index_first_axis(tensor, indices):
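After this change the module no longer raises at import time when a backend is missing, so a quick way to see which implementation the chain resolved to is to inspect the module of the bound function. An illustrative snippet (the printed strings are examples only, not output produced by the repo):

```python
from vllm_omni.diffusion.attention.backends.utils import fa

if fa.HAS_FLASH_ATTN:
    # e.g. "fa3_fwd_interface", "flash_attn_interface", or "flash_attn.flash_attn_interface"
    print("flash attention resolved from:", fa.flash_attn_func.__module__)
else:
    print("no flash attention found; select SDPA with DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA")
```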
