
Commit c386f22

committed
up
1 parent 310fdaf commit c386f22

File tree

2 files changed, +65 -20 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 42 additions & 7 deletions
@@ -80,12 +80,15 @@
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _get_fa3_from_hub, get_fa_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (

@@ -170,6 +173,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub"  # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"

@@ -400,15 +405,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
         )
 
-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of FA and FA3 varlen later
+    elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
     elif backend in [

@@ -1225,6 +1230,36 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
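For context, here is how the new backend would be exercised end to end. The backend name "flash_hub", the `DIFFUSERS_ENABLE_HUB_KERNELS` env var, and the `pip install kernels` requirement come from the diff above; the `set_attention_backend` helper and the Flux checkpoint are assumptions used only for illustration, not part of this commit.

# Minimal usage sketch (assumptions flagged above).
# Run with: pip install kernels && export DIFFUSERS_ENABLE_HUB_KERNELS=yes
import torch
from diffusers import FluxPipeline  # example pipeline; any model routed through attention_dispatch works

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# "flash_hub" is the string registered for AttentionBackendName.FLASH_HUB above;
# it dispatches attention to the kernels-community/flash-attn kernel fetched from the Hub.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe("a photo of a forest at dawn", num_inference_steps=28).images[0]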

src/diffusers/utils/kernels_utils.py

Lines changed: 23 additions & 13 deletions
@@ -2,22 +2,32 @@
 from .import_utils import is_kernels_available
 
 
-logger = get_logger(__name__)
+if is_kernels_available():
+    from kernels import get_kernel
 
+logger = get_logger(__name__)
 
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
 
 
-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
     if not is_kernels_available():
         return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise
+
+
+def get_fa3_from_hub():
+    return _get_from_hub("fa3")
+
+
+def get_fa_from_hub():
+    return _get_from_hub("fa")
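The two public helpers are thin wrappers over `kernels.get_kernel` with the Hub IDs stored in `_DEFAULT_HUB_IDS`. Below is a hedged sketch of what a successful fetch resolves to; the Hub IDs, the revision, and the keyword names mirror this diff and the `_flash_attention_hub` call above, while the `(batch, seq_len, num_heads, head_dim)` layout is the usual flash-attn convention and should be treated as an assumption here.

# Sketch only, not part of this commit.
import torch
from kernels import get_kernel  # pip install kernels

fa = get_kernel("kernels-community/flash-attn")  # same ID as _DEFAULT_HUB_IDS["fa"]
fa3 = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")  # _DEFAULT_HUB_IDS["fa3"]

q = k = v = torch.randn(2, 1024, 8, 64, dtype=torch.bfloat16, device="cuda")
# Keyword names follow the call in _flash_attention_hub (q, k, v, dropout_p, softmax_scale, causal).
out = fa.flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=False)
print(out.shape)  # expected: torch.Size([2, 1024, 8, 64])

Keeping the Hub IDs in one table means adding another hosted kernel later is a one-line entry plus a small public wrapper.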

0 commit comments
