77 changes: 67 additions & 10 deletions src/diffusers/models/attention_dispatch.py
@@ -68,17 +68,32 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None

-if DIFFUSERS_ENABLE_HUB_KERNELS:
-    if not is_kernels_available():
-        raise ImportError(
-            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
-        )
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
-else:
-    flash_attn_3_func_hub = None
+flash_attn_func_hub = None
+flash_attn_3_func_hub = None
+
+try:
+    if DIFFUSERS_ENABLE_HUB_KERNELS:
+        if not is_kernels_available():
+            raise ImportError(
+                "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+            )
+        from ..utils.kernels_utils import _get_fa3_from_hub
+
+        flash_attn_interface_hub = _get_fa3_from_hub()
+        flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+        flash_attn_func_hub = flash_attn_3_func_hub  # point the generic hub variable at the FA3 kernel
+    else:
+        if not is_kernels_available():
+            raise ImportError(
+                "To use FA kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+            )
+        from ..utils.kernels_utils import _get_fa_from_hub
+
+        flash_attn_interface_hub = _get_fa_from_hub()
+        flash_attn_func_hub = flash_attn_interface_hub.flash_attn_func
+except Exception:
+    # Keep variables as None if initialization fails
+    pass

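For reference, a minimal sketch of what this import-time block resolves to when everything is available — `get_kernel` is the `kernels` library's real entry point, and the repo id is the one defined in kernels_utils.py below, but the tensor shapes and the `causal` flag are illustrative assumptions, not part of this PR:

    import torch
    from kernels import get_kernel

    # Downloads (and caches) the pre-built flash-attention kernel from the Hub.
    flash_attn = get_kernel("kernels-community/flash-attn")

    # flash-attn layout: (batch, seq_len, num_heads, head_dim), bf16/fp16 on CUDA.
    q = k = v = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
    out = flash_attn.flash_attn_func(q, k, v, causal=False)  # illustrative call
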
if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -167,6 +182,7 @@ class AttentionBackendName(str, Enum):
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
_FLASH_HUB = "_flash_hub"

# PyTorch native
FLEX = "flex"
@@ -376,6 +392,16 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend in [AttentionBackendName._FLASH_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
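Since the hub kernel is fetched in the module-level `try` block above, `DIFFUSERS_ENABLE_HUB_KERNELS` must be set before `diffusers` is imported. A hedged usage sketch — the pipeline and checkpoint are illustrative, and it assumes a diffusers build that exposes `set_attention_backend` on its models:

    import os

    # Must be set before importing diffusers; the kernel is fetched at import time.
    os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

    import torch
    from diffusers import FluxPipeline  # illustrative model choice

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
    pipe.transformer.set_attention_backend("_flash_hub")  # the enum value added above
    image = pipe("a photo of a cat").images[0]
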
@@ -720,6 +746,37 @@ def _flash_attention_3_hub(
return (out[0], out[1]) if return_attn_probs else out


@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out = flash_attn_func_hub(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=return_attn_probs,
)
    return (out[0], out[1]) if return_attn_probs else out

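A direct-call sketch of the new backend, assuming the import-time block populated `flash_attn_func_hub` (i.e. `kernels` is installed and the kernel fetch succeeded); the shapes follow the (batch, seq_len, heads, head_dim) layout flash-attn expects and are illustrative:

    import torch
    from diffusers.models.attention_dispatch import _flash_attention_hub

    q = torch.randn(2, 1024, 16, 64, dtype=torch.bfloat16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = _flash_attention_hub(query=q, key=k, value=v, is_causal=False)
    print(out.shape)  # torch.Size([2, 1024, 16, 64])

In normal use the registry dispatches here by backend name; the private function is called directly above only for illustration.
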

@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
16 changes: 16 additions & 0 deletions src/diffusers/utils/kernels_utils.py
@@ -7,6 +7,8 @@

_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"

_DEFAULT_HUB_ID_FA = "kernels-community/flash-attn"


def _get_fa3_from_hub():
if not is_kernels_available():
@@ -21,3 +23,17 @@ def _get_fa3_from_hub():
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise


def _get_fa_from_hub():
if not is_kernels_available():
return None
else:
from kernels import get_kernel

try:
flash_attn_hub = get_kernel(_DEFAULT_HUB_ID_FA)
return flash_attn_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA}' from the Hub: {e}")
raise
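
A short consumption sketch mirroring how attention_dispatch.py uses the FA3 helper — the attribute access is the same one the dispatch code binds at import time:

    from diffusers.utils.kernels_utils import _get_fa_from_hub

    interface = _get_fa_from_hub()  # returns None when `kernels` isn't installed
    if interface is not None:
        flash_attn_func = interface.flash_attn_func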