Commit 827fc15

feat: try loading fa3 using kernels when available.
1 parent cf1ca72 commit 827fc15

File tree

1 file changed: +13, -4 lines

src/diffusers/models/attention_dispatch.py

Lines changed: 13 additions & 4 deletions
@@ -38,6 +38,8 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
@@ -52,6 +54,7 @@
 _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
 _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)

+_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"

 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -64,8 +67,16 @@
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
-    flash_attn_3_func = None
-    flash_attn_3_varlen_func = None
+    try:
+        from kernels import get_kernel
+
+        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
+        flash_attn_3_func = vllm_flash_attn3.flash_attn_func
+        flash_attn_3_varlen_func = vllm_flash_attn3.flash_attn_varlen_func
+        logger.debug(f"Using Flash Attention 3 from {_DEFAULT_HUB_ID_FA3} using the `kernels` lib.")
+    except ImportError:
+        flash_attn_3_func = None
+        flash_attn_3_varlen_func = None


 if _CAN_USE_SAGE_ATTN:
@@ -132,8 +143,6 @@ def wrap(func):
     _register_fake = register_fake_no_op


-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
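The fallback added above can be read in isolation as the sketch below. It is a minimal, self-contained illustration of the same pattern, not the module itself: the availability flag diffusers uses for the native Flash Attention 3 build is replaced here by a plain try/except around the `flash_attn_interface` import, and the rest of attention_dispatch.py is omitted.

# Minimal sketch of the fallback introduced in this commit (illustration only;
# the module's availability-flag check is replaced by a try/except on the
# native import, and the surrounding dispatch plumbing is omitted).
_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"

try:
    # Prefer a locally installed Flash Attention 3 build.
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except ImportError:
    try:
        # Otherwise, try fetching the pre-built FA3 kernel from the Hub
        # via the `kernels` library.
        from kernels import get_kernel

        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
        flash_attn_3_func = vllm_flash_attn3.flash_attn_func
        flash_attn_3_varlen_func = vllm_flash_attn3.flash_attn_varlen_func
    except ImportError:
        # Neither source is available; the FA3 backend stays disabled.
        flash_attn_3_func = None
        flash_attn_3_varlen_func = None

# Callers can then gate on availability, for example:
if flash_attn_3_func is None:
    print("Flash Attention 3 backend is unavailable; pick another backend.")

Note that, as in the diff, only ImportError is caught, so a missing `kernels` package falls through quietly to the disabled state, while other failures (for example, a failed kernel download inside get_kernel) would still surface to the caller.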
