Commit 00bd07b

[XPU] Update Bagel's flash_attn_varlen_func to fa utils (vllm-project#1295)
Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
1 parent 36a7971 commit 00bd07b

2 files changed: +6 -1 lines changed

vllm_omni/diffusion/attention/backends/utils/fa.py

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,11 @@
         from aiter import flash_attn_func, flash_attn_varlen_func  # noqa: F401
     except (ImportError, ModuleNotFoundError):
         pass
+elif current_omni_platform.is_xpu():
+    try:
+        from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func  # noqa: F401
+    except (ImportError, ModuleNotFoundError):
+        pass
 else:
     # CUDA: try FA3 -> FA2 fallback chain
     # Try FA3 from fa3-fwd PyPI package

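For context, a minimal sketch of the full platform-gated import chain that this hunk extends. Only the XPU branch is taken directly from the commit; the ROCm branch condition, the import path of current_omni_platform, and the None default are assumptions reconstructed from the visible context lines, not a verbatim copy of the file:

# Hypothetical reconstruction of fa.py's fallback chain; only the XPU branch
# comes from this commit, the rest is inferred from the surrounding context.
from vllm_omni.platforms import current_omni_platform  # assumed import path

flash_attn_varlen_func = None  # assumed default when no backend is available

if current_omni_platform.is_rocm():  # assumed branch guarding the aiter import
    try:
        from aiter import flash_attn_func, flash_attn_varlen_func  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        pass
elif current_omni_platform.is_xpu():  # new in this commit
    try:
        from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        pass
else:
    # CUDA: try FA3 -> FA2 fallback chain (elided in this sketch)
    ...
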
vllm_omni/diffusion/models/bagel/bagel_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -30,8 +30,8 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.bagel import BagelConfig
-from vllm.vllm_flash_attn import flash_attn_varlen_func
 
+from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding

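A minimal usage sketch of the call site this import change feeds: the varlen attention call itself does not change, only where flash_attn_varlen_func comes from. The tensor shapes, the cu_seqlens construction, and the keyword arguments below follow the standard flash-attn varlen convention and are illustrative assumptions, not Bagel's actual code:

import torch
from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func

# Two packed sequences of lengths 3 and 5 (8 tokens total), 4 heads, head dim 64.
seqlens = torch.tensor([3, 5], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

# Pick XPU when available (the point of this commit), otherwise fall back to CUDA.
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"
q = torch.randn(8, 4, 64, device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens.to(device),
    cu_seqlens_k=cu_seqlens.to(device),
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
)
# out has shape [8, 4, 64]; attention is computed independently per packed sequence.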