Skip to content

Commit 2fa8e53

Browse files
committed
update with proper handling.
1 parent 3ce378e commit 2fa8e53

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,27 +68,32 @@
6868
flash_attn_3_func = None
6969
flash_attn_3_varlen_func = None
7070

71-
if DIFFUSERS_ENABLE_HUB_KERNELS:
72-
if not is_kernels_available():
73-
raise ImportError(
74-
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
75-
)
76-
from ..utils.kernels_utils import _get_fa3_from_hub
71+
flash_attn_func_hub = None
72+
flash_attn_3_func_hub = None
7773

78-
flash_attn_interface_hub = _get_fa3_from_hub()
79-
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
80-
flash_attn_3_func_hub = None
74+
try:
75+
if DIFFUSERS_ENABLE_HUB_KERNELS:
76+
if not is_kernels_available():
77+
raise ImportError(
78+
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
79+
)
80+
from ..utils.kernels_utils import _get_fa3_from_hub
8181

82-
else:
83-
if not is_kernels_available():
84-
raise ImportError(
85-
"To use FA kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
86-
)
87-
from ..utils.kernels_utils import _get_fa_from_hub
82+
flash_attn_interface_hub = _get_fa3_from_hub()
83+
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
84+
flash_attn_func_hub = flash_attn_3_func_hub # point generic hub variable
85+
else:
86+
if not is_kernels_available():
87+
raise ImportError(
88+
"To use FA kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
89+
)
90+
from ..utils.kernels_utils import _get_fa_from_hub
8891

89-
flash_attn_interface_hub = _get_fa_from_hub()
90-
flash_attn_func_hub = flash_attn_interface_hub.flash_attn_func
91-
flash_attn_func_hub= None
92+
flash_attn_interface_hub = _get_fa_from_hub()
93+
flash_attn_func_hub = flash_attn_interface_hub.flash_attn_func
94+
except Exception:
95+
# Keep variables as None if initialization fails
96+
pass
9297

9398
if _CAN_USE_SAGE_ATTN:
9499
from sageattention import (

0 commit comments

Comments
 (0)