66 changes: 47 additions & 19 deletions src/diffusers/models/attention_dispatch.py
@@ -16,6 +16,7 @@
import functools
import inspect
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

@@ -40,7 +41,7 @@
is_xformers_available,
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


if TYPE_CHECKING:
@@ -78,18 +79,6 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
else:
flash_attn_3_func_hub = None

if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
@@ -249,6 +238,25 @@ def _is_context_parallel_enabled(
return supports_context_parallel and is_degree_greater_than_1


@dataclass
class _HubKernelConfig:
"""Configuration for downloading and using a hub-based attention kernel."""

repo_id: str
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None


# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
)
}


@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
Expand Down Expand Up @@ -405,13 +413,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None

# TODO: add support for the Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend in [
@@ -555,6 +559,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx


# ===== Helpers for downloading kernels =====
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]

if config.kernel_fn is not None:
return

try:
from kernels import get_kernel

kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)

# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func

except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
raise


# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1322,7 +1349,8 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
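Note on the flow above: the diff replaces the import-time, env-var-gated kernel download with a lazy, registry-driven one. A minimal sketch of how the new pieces fit together, assuming the module-level names introduced in this diff (`_HUB_KERNELS_REGISTRY`, `_maybe_download_kernel_for_backend`) and that the `kernels` package plus network access are available:

```python
# Sketch only: exercises the lazy-download path introduced above.
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _HUB_KERNELS_REGISTRY,
    _maybe_download_kernel_for_backend,
)

backend = AttentionBackendName._FLASH_3_HUB

# First call fetches the kernel from the Hub via `kernels.get_kernel`
# and caches the resolved function on the registry's config entry.
_maybe_download_kernel_for_backend(backend)

# Subsequent dispatches (e.g. _flash_attention_3_hub) are a plain
# attribute read; no further Hub traffic or env var is involved.
kernel_fn = _HUB_KERNELS_REGISTRY[backend].kernel_fn
assert kernel_fn is not None
```

Because the cache lives on the dataclass instance, repeated calls reuse the already-fetched kernel instead of re-downloading it.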
8 changes: 7 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
from .attention_dispatch import (
AttentionBackendName,
_check_attention_backend_requirements,
_maybe_download_kernel_for_backend,
)

# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None:
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
_maybe_download_kernel_for_backend(backend)

attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
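With the `DIFFUSERS_ENABLE_HUB_KERNELS` gate removed, opting into the Hub kernel is a single call on the model; requirements are validated and the kernel downloaded at that point rather than at import time. A hypothetical end-to-end usage (the model class, checkpoint, and the `"_flash_3_hub"` string value are assumptions, not confirmed by this diff):

```python
import torch
from diffusers import FluxTransformer2DModel  # placeholder model; any ModelMixin works

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Runs _check_attention_backend_requirements (is `kernels` installed?)
# and then _maybe_download_kernel_for_backend before any forward pass.
model.set_attention_backend("_flash_3_hub")
```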
1 change: 0 additions & 1 deletion src/diffusers/utils/constants.py
@@ -46,7 +46,6 @@
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
23 changes: 0 additions & 23 deletions src/diffusers/utils/kernels_utils.py

This file was deleted.

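For context, the deleted `kernels_utils.py` exposed a single eager helper whose job is now covered by `_HUB_KERNELS_REGISTRY`. A rough reconstruction inferred from its old call site in `attention_dispatch.py` (the constant name is a guess; this is not the verbatim file):

```python
# Approximate shape of the removed helper, reconstructed from its call site.
from kernels import get_kernel

_FA3_HUB_REPO_ID = "kernels-community/flash-attn3"  # assumed constant name

def _get_fa3_from_hub():
    # Eagerly fetched the FA3 kernel module at import time whenever
    # DIFFUSERS_ENABLE_HUB_KERNELS was set; the pattern this PR replaces.
    return get_kernel(_FA3_HUB_REPO_ID)
```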
1 change: 0 additions & 1 deletion tests/others/test_attention_backends.py
@@ -7,7 +7,6 @@

```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
export DIFFUSERS_ENABLE_HUB_KERNELS=yes

pytest tests/others/test_attention_backends.py
```