
Commit 4e69d42

Commit message: up
1 parent 2bb3796 commit 4e69d42

File tree: 2 files changed (+28 / -27 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 21 additions & 25 deletions
@@ -17,6 +17,7 @@
 import inspect
 import math
 from enum import Enum
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch
@@ -39,8 +40,6 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 _REQUIRED_FLASH_VERSION = "2.6.3"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
@@ -70,20 +69,6 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

-if is_kernels_available():
-    from ..utils.kernels_utils import _get_fa3_from_hub
-
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    if flash_attn_interface_hub is not None:
-        flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func
-        flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func
-    else:
-        flash_attn_3_hub_func = None
-        flash_attn_3_varlen_hub_func = None
-else:
-    flash_attn_3_hub_func = None
-    flash_attn_3_varlen_hub_func = None
-

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -148,6 +133,7 @@ def wrap(func):
     _custom_op = custom_op_no_op
     _register_fake = register_fake_no_op

+logger = get_logger(__name__)  # pylint: disable=invalid-name

 # TODO(aryan): Add support for the following:
 # - Sage Attention++
@@ -169,7 +155,7 @@ class AttentionBackendName(str, Enum):
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

     # PyTorch native
     FLEX = "flex"
@@ -224,6 +210,22 @@ def list_backends(cls):
         return list(cls._backends.keys())


+@lru_cache(maxsize=None)
+def _load_fa3_hub():
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
+    if fa3_hub is None:
+        raise RuntimeError(
+            "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
+        )
+    return fa3_hub
+
+
+def flash_attn_3_hub_func(*args, **kwargs):
+    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -374,12 +376,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
-        if flash_attn_3_hub_func is None:
-            raise RuntimeError(
-                "`flash_attn_3_hub_func` wasn't available. Please double if `kernels` was able to successfully pull the FA3 kernel from kernels-community/vllm-flash-attn3."
-            )
-    elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]:
-        raise NotImplementedError

     elif backend in [
         AttentionBackendName.SAGE,
@@ -544,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)


-@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -553,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse


-@_register_fake("vllm_flash_attn3::_flash_attn_forward")
+@_register_fake("vllm_flash_attn3::flash_attn")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
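
For context on the two renamed registrations: judging by the fallback branch in the earlier hunk (`_custom_op = custom_op_no_op`, `_register_fake = register_fake_no_op`), `_custom_op` and `_register_fake` appear to stand in for `torch.library.custom_op` and `torch.library.register_fake` when they are available. A rough sketch of that registration pattern, assuming PyTorch >= 2.4 and a hypothetical op name rather than the vllm_flash_attn3 kernel:

import torch
from torch.library import custom_op, register_fake


# Real implementation, registered under a hypothetical namespace::name.
@custom_op("mylib::scaled_matmul", mutates_args=(), device_types="cpu")
def scaled_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return (a @ b) * 0.5


# "Fake" (meta) implementation: returns a correctly shaped empty tensor so
# torch.compile / torch.export can trace the op without running it.
@register_fake("mylib::scaled_matmul")
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a.new_empty((a.shape[0], b.shape[1]))


out = torch.ops.mylib.scaled_matmul(torch.randn(2, 3), torch.randn(3, 4))
print(out.shape)  # torch.Size([2, 4])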

src/diffusers/utils/kernels_utils.py

Lines changed: 7 additions & 2 deletions
@@ -1,6 +1,10 @@
+from ..utils import get_logger
 from .import_utils import is_kernels_available


+logger = get_logger(__name__)
+
+
 _DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"


@@ -13,5 +17,6 @@ def _get_fa3_from_hub():
     try:
         flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
         return flash_attn_3_hub
-    except Exception:
-        return None
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+        raise
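
Note on the kernels_utils.py change: previously a failed Hub fetch was silently swallowed and only surfaced later as an opaque None; now the error is logged with the kernel repo id and re-raised, so the original exception and traceback reach the caller (here, `_load_fa3_hub`). A generic sketch of that log-and-re-raise pattern using the standard library logger (stand-in names, not the diffusers helpers):

import logging

logger = logging.getLogger(__name__)

KERNEL_REPO = "example-org/example-kernel"  # hypothetical repo id


def fetch_kernel():
    try:
        raise OSError("network unreachable")  # stand-in for the real fetch
    except Exception as e:
        # Log enough context to debug, then re-raise so the caller still sees
        # the original exception and traceback instead of a bare None.
        logger.error(f"An error occurred while fetching kernel '{KERNEL_REPO}': {e}")
        raise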
