Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
Expand All @@ -38,6 +39,8 @@
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


logger = get_logger(__name__) # pylint: disable=invalid-name

_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
Expand Down Expand Up @@ -67,6 +70,20 @@
flash_attn_3_func = None
flash_attn_3_varlen_func = None

if is_kernels_available():
from ..utils.kernels_utils import _get_fa3_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
if flash_attn_interface_hub is not None:
flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
else:
flash_attn_3_hub_func = None
flash_attn_3_varlen_hub_func = None
else:
flash_attn_3_hub_func = None
flash_attn_3_varlen_hub_func = None


if _CAN_USE_SAGE_ATTN:
from sageattention import (
Expand Down Expand Up @@ -132,8 +149,6 @@ def wrap(func):
_register_fake = register_fake_no_op


logger = get_logger(__name__) # pylint: disable=invalid-name

# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
Expand All @@ -153,6 +168,8 @@ class AttentionBackendName(str, Enum):
FLASH_VARLEN = "flash_varlen"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
_FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.

# PyTorch native
FLEX = "flex"
Expand Down Expand Up @@ -351,6 +368,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)

# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
raise NotImplementedError

elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
Expand Down Expand Up @@ -514,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
return torch.empty_like(query), query.new_empty(lse_shape)


@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hub(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_3_hub_func(query, key, value)
lse = lse.permute(0, 2, 1)
return out, lse


@_register_fake("vllm_flash_attn3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(query), query.new_empty(lse_shape)


# ===== Attention backends =====


Expand Down Expand Up @@ -657,6 +699,41 @@ def _flash_attention_3(
return (out, lse) if return_attn_probs else out


@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out, lse, *_ = flash_attn_3_hub_func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
return (out, lse) if return_attn_probs else out


@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
Expand Down
17 changes: 17 additions & 0 deletions src/diffusers/utils/kernels_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .import_utils import is_kernels_available


_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"


def _get_fa3_from_hub():
if not is_kernels_available():
return None
else:
from kernels import get_kernel

try:
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
return flash_attn_3_hub
except Exception as e:
raise e
Loading