Commit bc40971

change to Hub.
1 parent ac43e84 commit bc40971

File tree: 2 files changed, +113 −38 lines

src/diffusers/models/attention_dispatch.py

Lines changed: 96 additions & 38 deletions
@@ -26,6 +26,7 @@
     is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
+    is_kernels_available,
    is_sageattention_available,
    is_sageattention_version,
    is_torch_npu_available,
@@ -54,7 +55,6 @@
5455
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
5556
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
5657

57-
_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
5858

5959
if _CAN_USE_FLASH_ATTN:
6060
from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -67,16 +67,22 @@
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
-    try:
-        from kernels import get_kernel
+    flash_attn_3_func = None
+    flash_attn_3_varlen_func = None

-        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
-        flash_attn_3_func = vllm_flash_attn3.flash_attn_func
-        flash_attn_3_varlen_func = vllm_flash_attn3.flash_attn_varlen_func
-        logger.debug(f"Using Flash Attention 3 from {_DEFAULT_HUB_ID_FA3} using the `kernels` lib.")
-    except ImportError:
-        flash_attn_3_func = None
-        flash_attn_3_varlen_func = None
+if is_kernels_available():
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    if flash_attn_interface_hub is not None:
+        flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
+        flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
+    else:
+        flash_attn_3_hub_func = None
+        flash_attn_3_varlen_hub_func = None
+else:
+    flash_attn_3_hub_func = None
+    flash_attn_3_varlen_hub_func = None


 if _CAN_USE_SAGE_ATTN:
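After this hunk, the locally installed FA3 symbols (`flash_attn_3_func`, `flash_attn_3_varlen_func`) and the Hub-fetched ones (`flash_attn_3_hub_func`, `flash_attn_3_varlen_hub_func`) coexist as separate module-level names, each falling back to `None` when its source is unavailable. An illustrative probe of what an environment ended up with (a sketch, not part of the diff):

import torch  # noqa: F401  (importing the module below may trigger the Hub fetch)
from diffusers.models import attention_dispatch as ad

# Each name is None when its corresponding source (local build or Hub kernel) is missing.
print("local FA3 available:", ad.flash_attn_3_func is not None)
print("Hub FA3 available:  ", ad.flash_attn_3_hub_func is not None)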
@@ -162,6 +168,8 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
+    _FLASH_3_HUB = "_flash_3_hub"
+    _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

     # PyTorch native
     FLEX = "flex"
@@ -355,11 +363,20 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
             )

     elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
-        if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
+        if not _CAN_USE_FLASH_ATTN_3:
             raise RuntimeError(
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

+    # TODO: add support Hub variant of FA3 varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+    elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
+        raise NotImplementedError
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
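The new enum members make the Hub-backed kernel selectable through the same string-keyed backend registry as the existing backends, and the requirement check above fails fast when `kernels` is missing. A minimal sketch of that flow (not part of the commit; it assumes the module is importable as `diffusers.models.attention_dispatch`):

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_attention_backend_requirements,
)

# "_flash_3_hub" resolves to the AttentionBackendName._FLASH_3_HUB member added above.
backend = AttentionBackendName("_flash_3_hub")

try:
    # Raises RuntimeError when the `kernels` package is not installed.
    _check_attention_backend_requirements(backend)
except RuntimeError as err:
    print(err)  # the message suggests `pip install kernels`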
@@ -523,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.empty_like(query), query.new_empty(lse_shape)


+@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3_hub(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    out, lse = flash_attn_3_hub_func(query, key, value)
+    lse = lse.permute(0, 2, 1)
+    return out, lse
+
+
+@_register_fake("flash_attn_3_hub_func")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch_size, seq_len, num_heads, head_dim = query.shape
+    lse_shape = (batch_size, seq_len, num_heads)
+    return torch.empty_like(query), query.new_empty(lse_shape)
+
+
 # ===== Attention backends =====

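Registering the Hub function as a custom op with a shape-only fake implementation lets tracing tools such as `torch.compile` propagate output shapes without launching the CUDA kernel. The fake also pins down the layout: the kernel's raw LSE comes back as (batch, num_heads, seq_len) and `_wrapped_flash_attn_3_hub` permutes it to (batch, seq_len, num_heads). A small illustrative check of that shape contract (a sketch, not part of the diff):

import torch

# Meta tensors are enough to exercise the shapes the fake registration promises.
batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
query = torch.empty(batch, seq_len, num_heads, head_dim, device="meta")

out_fake = torch.empty_like(query)                       # same shape as query
lse_fake = query.new_empty((batch, seq_len, num_heads))  # LSE after the permute

assert out_fake.shape == query.shape
assert lse_fake.shape == (batch, seq_len, num_heads)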

@@ -645,34 +678,59 @@ def _flash_attention_3(
     deterministic: bool = False,
     return_attn_probs: bool = False,
 ) -> torch.Tensor:
-    sig = inspect.signature(flash_attn_3_func)
-    accepted = set(sig.parameters)
-    params = {
-        "q": query,
-        "k": key,
-        "v": value,
-        "softmax_scale": scale,
-        "causal": is_causal,
-        "qv": None,
-        "q_descale": None,
-        "k_descale": None,
-        "v_descale": None,
-        "window_size": window_size,
-        "attention_chunk": 0,
-        "softcap": softcap,
-        "num_splits": 1,
-        "pack_gqa": None,
-        "deterministic": deterministic,
-        "sm_margin": 0,
-    }
-    kwargs = {}
-    for name, value in params.items():
-        if name not in accepted:
-            logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")
-        else:
-            kwargs[name] = value
+    out, lse, *_ = flash_attn_3_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        attention_chunk=0,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    return (out, lse) if return_attn_probs else out
+

-    out, lse, *_ = flash_attn_3_func(**kwargs)
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out, lse, *_ = flash_attn_3_hub_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
     return (out, lse) if return_attn_probs else out

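A hedged usage sketch of the new backend function (not part of the commit): it assumes an FA3-capable GPU, `pip install kernels`, and network access so that the `kernels-community/vllm-flash-attn3` kernel can be fetched. In normal use this function is reached through the attention dispatcher rather than called directly.

import torch
from diffusers.models.attention_dispatch import _flash_attention_3_hub

# flash-attn layout: (batch, seq_len, num_heads, head_dim), bf16 or fp16.
batch, seq_len, num_heads, head_dim = 2, 1024, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = _flash_attention_3_hub(q, k, v, is_causal=True)
print(out.shape)  # same shape as q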

src/diffusers/utils/kernels_utils.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from .import_utils import is_kernels_available
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
+
+
+def _get_fa3_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
+            return vllm_flash_attn3
+        except Exception as e:
+            raise e
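A short sketch of how this helper is consumed (mirroring the attention_dispatch.py change above, and assuming the new module lives at `diffusers.utils.kernels_utils`): fetch the kernel from the Hub once and read the attention entry points off the returned module.

from diffusers.utils.kernels_utils import _get_fa3_from_hub

fa3 = _get_fa3_from_hub()  # None when the `kernels` package is unavailable
if fa3 is not None:
    flash_attn_3_hub_func = fa3.flash_attn_func
    flash_attn_3_varlen_hub_func = fa3.flash_attn_varlen_func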
