|
68 | 68 | flash_attn_3_func = None |
69 | 69 | flash_attn_3_varlen_func = None |
70 | 70 |
|
71 | | -if DIFFUSERS_ENABLE_HUB_KERNELS: |
72 | | - if not is_kernels_available(): |
73 | | - raise ImportError( |
74 | | - "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." |
75 | | - ) |
76 | | - from ..utils.kernels_utils import _get_fa3_from_hub |
| 71 | +flash_attn_func_hub = None |
| 72 | +flash_attn_3_func_hub = None |
77 | 73 |
|
78 | | - flash_attn_interface_hub = _get_fa3_from_hub() |
79 | | - flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func |
80 | | - flash_attn_3_func_hub = None |
| 74 | +try: |
| 75 | + if DIFFUSERS_ENABLE_HUB_KERNELS: |
| 76 | + if not is_kernels_available(): |
| 77 | + raise ImportError( |
| 78 | + "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." |
| 79 | + ) |
| 80 | + from ..utils.kernels_utils import _get_fa3_from_hub |
81 | 81 |
|
82 | | -else: |
83 | | - if not is_kernels_available(): |
84 | | - raise ImportError( |
85 | | - "To use FA kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." |
86 | | - ) |
87 | | - from ..utils.kernels_utils import _get_fa_from_hub |
| 82 | + flash_attn_interface_hub = _get_fa3_from_hub() |
| 83 | + flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func |
| 84 | + flash_attn_func_hub = flash_attn_3_func_hub # point generic hub variable |
| 85 | + else: |
| 86 | + if not is_kernels_available(): |
| 87 | + raise ImportError( |
| 88 | + "To use FA kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." |
| 89 | + ) |
| 90 | + from ..utils.kernels_utils import _get_fa_from_hub |
88 | 91 |
|
89 | | - flash_attn_interface_hub = _get_fa_from_hub() |
90 | | - flash_attn_func_hub = flash_attn_interface_hub.flash_attn_func |
91 | | - flash_attn_func_hub= None |
| 92 | + flash_attn_interface_hub = _get_fa_from_hub() |
| 93 | + flash_attn_func_hub = flash_attn_interface_hub.flash_attn_func |
| 94 | +except Exception: |
| 95 | + # Keep variables as None if initialization fails |
| 96 | + pass |
92 | 97 |
|
93 | 98 | if _CAN_USE_SAGE_ATTN: |
94 | 99 | from sageattention import ( |
|
0 commit comments