diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f71be7c8ecc0..d98859abf0f0 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -68,17 +68,32 @@
 flash_attn_3_func = None
 flash_attn_3_varlen_func = None
 
-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+flash_attn_func_hub = None
+flash_attn_3_func_hub = None
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
+try:
+    if DIFFUSERS_ENABLE_HUB_KERNELS:
+        if not is_kernels_available():
+            raise ImportError(
+                "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+            )
+        from ..utils.kernels_utils import _get_fa3_from_hub
+
+        flash_attn_interface_hub = _get_fa3_from_hub()
+        flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+        flash_attn_func_hub = flash_attn_3_func_hub  # point the generic hub variable at the FA3 kernel
+    else:
+        if not is_kernels_available():
+            raise ImportError(
+                "To use FA kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+            )
+        from ..utils.kernels_utils import _get_fa_from_hub
+
+        flash_attn_interface_hub = _get_fa_from_hub()
+        flash_attn_func_hub = flash_attn_interface_hub.flash_attn_func
+except Exception:
+    # Leave the hub function handles as None if kernel initialization fails.
+    pass
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -167,6 +182,7 @@ class AttentionBackendName(str, Enum):
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
     # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
+    _FLASH_HUB = "_flash_hub"
 
     # PyTorch native
     FLEX = "flex"
@@ -376,6 +392,16 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
+    elif backend in [AttentionBackendName._FLASH_HUB]:
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -720,6 +746,37 @@ def _flash_attention_3_hub(
     return (out[0], out[1]) if return_attn_probs else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        window_size=window_size,
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        deterministic=deterministic,
+        return_attn_probs=return_attn_probs,
+    )
+    return (out[0], out[1]) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972fb7..5fa630823e31 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -7,6 +7,8 @@
 
 _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
 
+_DEFAULT_HUB_ID_FA = "kernels-community/flash-attn"
+
 
 def _get_fa3_from_hub():
     if not is_kernels_available():
@@ -21,3 +23,17 @@ def _get_fa3_from_hub():
         except Exception as e:
             logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
             raise
+
+
+def _get_fa_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            flash_attn_hub = get_kernel(_DEFAULT_HUB_ID_FA)
+            return flash_attn_hub
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA}' from the Hub: {e}")
+            raise
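
A minimal usage sketch for the new `_flash_hub` backend, assuming the `set_attention_backend` helper on recent diffusers models; the Flux checkpoint, prompt, and output filename below are illustrative placeholders. `DIFFUSERS_ENABLE_HUB_KERNELS=yes` has to be exported before diffusers is imported, because the hub kernel is resolved at import time.

    # Run with: DIFFUSERS_ENABLE_HUB_KERNELS=yes python smoke_test.py
    import torch
    from diffusers import FluxPipeline

    # Load an example pipeline (illustrative checkpoint; any supported model works).
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # "_flash_hub" is the AttentionBackendName value registered by this change;
    # `set_attention_backend` is assumed to be available on the model (recent diffusers).
    pipe.transformer.set_attention_backend("_flash_hub")

    image = pipe("a small robot reading a book", num_inference_steps=28).images[0]
    image.save("flash_hub_smoke_test.png")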
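
A quick sanity check for the new `_get_fa_from_hub` helper, assuming the `kernels` package is installed; the kernel is downloaded from `kernels-community/flash-attn` on first use.

    from diffusers.utils.kernels_utils import _get_fa_from_hub

    # Returns None when `kernels` is not installed; otherwise fetches the Hub kernel.
    fa_kernel = _get_fa_from_hub()
    if fa_kernel is not None:
        # attention_dispatch binds this callable to `flash_attn_func_hub`.
        print(callable(fa_kernel.flash_attn_func))  # expected: True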