src/diffusers/models/attention_dispatch.py (65 changes: 42 additions & 23 deletions)
@@ -38,6 +38,8 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
@@ -52,6 +54,7 @@
 _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
 _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)

+_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"

 if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -64,8 +67,16 @@
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
-    flash_attn_3_func = None
-    flash_attn_3_varlen_func = None
+    try:
+        from kernels import get_kernel
+
+        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
+        flash_attn_3_func = vllm_flash_attn3.flash_attn_func
+        flash_attn_3_varlen_func = vllm_flash_attn3.flash_attn_varlen_func
+        logger.debug(f"Using Flash Attention 3 from {_DEFAULT_HUB_ID_FA3} using the `kernels` lib.")
+    except ImportError:
+        flash_attn_3_func = None
+        flash_attn_3_varlen_func = None


 if _CAN_USE_SAGE_ATTN:
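For reference, a minimal standalone sketch of what this fallback does: it pulls the FA3 kernel published on the Hub with the `kernels` library and exposes its attention functions. The tensor shapes, dtype, and device below are illustrative assumptions rather than values taken from this file, and the snippet requires a GPU supported by the vllm-flash-attn3 build.

import torch
from kernels import get_kernel

# Download (or load from the local cache) the FA3 kernel published on the Hub.
vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")

# Illustrative inputs: (batch, seq_len, num_heads, head_dim) in bf16 on CUDA.
q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)

# As in the dispatcher code above, the call returns the output and the logsumexp.
out, lse, *_ = vllm_flash_attn3.flash_attn_func(q=q, k=k, v=v, causal=False)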
@@ -132,8 +143,6 @@ def wrap(func):
     _register_fake = register_fake_no_op


-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
@@ -346,7 +355,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
             )

     elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
-        if not _CAN_USE_FLASH_ATTN_3:
+        if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
             raise RuntimeError(
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
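The reworked condition only raises when neither path can supply FA3: the locally built flash_attn_interface package is unavailable and the Hub kernel fallback left both functions as None. A hypothetical helper, not part of this file, that restates the same rule after applying De Morgan's law:

def _fa3_is_usable(can_use_flash_attn_3, flash_attn_3_func, flash_attn_3_varlen_func):
    # Usable if the locally built flash_attn_interface package imports, or if the
    # Hub kernel fallback managed to populate either function.
    return can_use_flash_attn_3 or flash_attn_3_func is not None or flash_attn_3_varlen_func is not None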
@@ -636,24 +645,34 @@ def _flash_attention_3(
     deterministic: bool = False,
     return_attn_probs: bool = False,
 ) -> torch.Tensor:
-    out, lse, *_ = flash_attn_3_func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
-        causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
-        attention_chunk=0,
-        softcap=softcap,
-        num_splits=1,
-        pack_gqa=None,
-        deterministic=deterministic,
-        sm_margin=0,
-    )
+    sig = inspect.signature(flash_attn_3_func)
+    accepted = set(sig.parameters)
+    params = {
+        "q": query,
+        "k": key,
+        "v": value,
+        "softmax_scale": scale,
+        "causal": is_causal,
+        "qv": None,
+        "q_descale": None,
+        "k_descale": None,
+        "v_descale": None,
+        "window_size": window_size,
+        "attention_chunk": 0,
+        "softcap": softcap,
+        "num_splits": 1,
+        "pack_gqa": None,
+        "deterministic": deterministic,
+        "sm_margin": 0,
+    }
+    kwargs = {}
+    for name, value in params.items():
+        if name not in accepted:
+            logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")
+        else:
+            kwargs[name] = value
+
+    out, lse, *_ = flash_attn_3_func(**kwargs)
     return (out, lse) if return_attn_probs else out


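The signature filtering above guards against small parameter-set differences between the locally built FA3 package and the Hub-provided kernel. A self-contained sketch of the same pattern with a toy callable; all names in it are hypothetical and independent of the diffusers code.

import inspect

def filter_to_accepted_kwargs(fn, **candidates):
    # Keep only the keyword arguments that appear in fn's signature.
    accepted = set(inspect.signature(fn).parameters)
    return {name: value for name, value in candidates.items() if name in accepted}

def toy_attention(q, k, v, causal=False):
    # A stand-in that, unlike some FA3 builds, accepts no `softcap` argument.
    return q

kwargs = filter_to_accepted_kwargs(toy_attention, q=1.0, k=2.0, v=3.0, causal=True, softcap=0.0)
assert "softcap" not in kwargs and kwargs["causal"] is True
out = toy_attention(**kwargs)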