diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..9d8896f91b3c 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -16,6 +16,7 @@
 import functools
 import inspect
 import math
+from dataclasses import dataclass
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
@@ -40,7 +41,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
 if TYPE_CHECKING:
@@ -78,18 +79,6 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
-
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
@@ -249,6 +238,25 @@ def _is_context_parallel_enabled(
     return supports_context_parallel and is_degree_greater_than_1
 
 
+@dataclass
+class _HubKernelConfig:
+    """Configuration for downloading and using a hub-based attention kernel."""
+
+    repo_id: str
+    function_attr: str
+    revision: Optional[str] = None
+    kernel_fn: Optional[Callable] = None
+
+
+# Registry for hub-based attention kernels
+_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+    )
+}
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -405,13 +413,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
 
     # TODO: add support Hub variant of FA3 varlen later
     elif backend in [AttentionBackendName._FLASH_3_HUB]:
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
-            )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [
@@ -555,6 +559,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
     return q_idx >= kv_idx
 
 
+# ===== Helpers for downloading kernels =====
+def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
+    if backend not in _HUB_KERNELS_REGISTRY:
+        return
+    config = _HUB_KERNELS_REGISTRY[backend]
+
+    if config.kernel_fn is not None:
+        return
+
+    try:
+        from kernels import get_kernel
+
+        kernel_module = get_kernel(config.repo_id, revision=config.revision)
+        kernel_func = getattr(kernel_module, config.function_attr)
+
+        # Cache the downloaded kernel function in the config object
+        config.kernel_fn = kernel_func
+
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
+        raise
+
+
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1322,7 +1349,8 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    out = flash_attn_3_func_hub(
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
+    out = func(
         q=query,
         k=key,
         v=value,
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 91daca1ad809..3880418fb03e 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
             attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _maybe_download_kernel_for_backend,
+        )
 
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None:
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
         backend = AttentionBackendName(backend)
         _check_attention_backend_requirements(backend)
+        _maybe_download_kernel_for_backend(backend)
 
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 42a53e181034..051a0c034e52 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -46,7 +46,6 @@
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
deleted file mode 100644
index 26d6e3972fb7..000000000000
--- a/src/diffusers/utils/kernels_utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from ..utils import get_logger
-from .import_utils import is_kernels_available
-
-
-logger = get_logger(__name__)
-
-
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
-
-
-def _get_fa3_from_hub():
-    if not is_kernels_available():
-        return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 42cdcd56f74a..8f4667792a02 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -7,7 +7,6 @@
 
 ```bash
 export RUN_ATTENTION_BACKEND_TESTS=yes
-export DIFFUSERS_ENABLE_HUB_KERNELS=yes
 pytest tests/others/test_attention_backends.py
 ```
 
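Usage after this change (a hedged sketch, not part of the diff): with `kernels` installed, selecting the FA3 Hub backend no longer requires the `DIFFUSERS_ENABLE_HUB_KERNELS` env var. `set_attention_backend("_flash_3_hub")` now validates requirements and then lazily downloads and caches the kernel via `_maybe_download_kernel_for_backend` and `_HUB_KERNELS_REGISTRY`. The checkpoint and prompt below are illustrative only and assume FA3-capable hardware.

```python
# Illustrative sketch, assuming a CUDA device that supports Flash Attention 3,
# `pip install kernels`, and access to the example checkpoint below.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# No env var needed anymore: this validates the backend, then fetches and
# caches the kernel from kernels-community/flash-attn3 on first use.
pipe.transformer.set_attention_backend("_flash_3_hub")

image = pipe("a photo of a cat").images[0]
```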