
Commit 6c15a10

update
1 parent 9db9be6 commit 6c15a10

2 files changed: +85 −17 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 81 additions & 17 deletions
@@ -38,26 +38,37 @@
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
-
-if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+if _CAN_USE_FLASH_ATTN:
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 else:
-    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
     flash_attn_func = None
     flash_attn_varlen_func = None
 
 
-if is_flash_attn_3_available():
+if _CAN_USE_FLASH_ATTN_3:
     from flash_attn_interface import flash_attn_func as flash_attn_3_func
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
 
-if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
+if _CAN_USE_SAGE_ATTN:
     from sageattention import (
         sageattn,
         sageattn_qk_int8_pv_fp8_cuda,
@@ -67,9 +78,6 @@
         sageattn_varlen,
     )
 else:
-    logger.warning(
-        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
-    )
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
     sageattn_qk_int8_pv_fp16_triton = None
@@ -78,39 +86,39 @@
     sageattn_varlen = None
 
 
-if is_torch_version(">=", "2.5.0"):
+if _CAN_USE_FLEX_ATTN:
     # We cannot import the flex_attention function from the package directly because it is expected (from the
     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
     # compiled function.
     import torch.nn.attention.flex_attention as flex_attention
 
 
-if is_torch_npu_available():
+if _CAN_USE_NPU_ATTN:
     from torch_npu import npu_fusion_attention
 else:
     npu_fusion_attention = None
 
 
-if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+if _CAN_USE_XLA_ATTN:
     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
 else:
     xla_flash_attention = None
 
 
-if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
+if _CAN_USE_XFORMERS_ATTN:
     import xformers.ops as xops
 else:
-    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
     xops = None
 
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-
 _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -171,6 +179,7 @@ def decorator(func):
 
     @classmethod
     def get_active_backend(cls):
+        _check_backend_requirements(cls._active_backend)
         return cls._active_backend, cls._backends[cls._active_backend]
 
     @classmethod
@@ -226,9 +235,10 @@ def dispatch_attention_fn(
         "dropout_p": dropout_p,
         "is_causal": is_causal,
         "scale": scale,
-        "enable_gqa": enable_gqa,
         **attention_kwargs,
     }
+    if is_torch_version(">=", "2.5.0"):
+        kwargs["enable_gqa"] = enable_gqa
 
     if _AttentionBackendRegistry._checks_enabled:
         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
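
The `enable_gqa` gating above exists because `torch.nn.functional.scaled_dot_product_attention` only accepts an `enable_gqa` argument from PyTorch 2.5 onward; passing it to older versions fails with an unexpected-keyword `TypeError`. Below is a minimal standalone sketch of the same pattern, assuming nothing from diffusers (the helper name and the `packaging`-based version probe are illustrative choices, not part of this commit):

import torch
import torch.nn.functional as F
from packaging import version

# True on torch builds whose SDPA knows about `enable_gqa` (added in 2.5).
_TORCH_SUPPORTS_ENABLE_GQA = version.parse(torch.__version__).release >= (2, 5)


def sdpa_with_optional_gqa(q, k, v, enable_gqa=False):
    # Forward the kwarg only when this torch build accepts it; older versions
    # would raise TypeError for an unexpected keyword argument.
    kwargs = {"enable_gqa": enable_gqa} if _TORCH_SUPPORTS_ENABLE_GQA else {}
    return F.scaled_dot_product_attention(q, k, v, **kwargs)


q = k = v = torch.randn(1, 2, 16, 8)  # (batch, heads, seq_len, head_dim)
out = sdpa_with_optional_gqa(q, k, v)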
@@ -305,6 +315,60 @@ def _check_shape(
 # ===== Helper functions =====
 
 
+# The LRU cache is a hack to avoid checking the backend requirements multiple times. It may not be needed,
+# because the CPU runs far enough ahead of the accelerator that this check is not blocking anyway.
+@functools.lru_cache(maxsize=16)
+def _check_backend_requirements(backend: AttentionBackendName) -> None:
+    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+        if not _CAN_USE_FLASH_ATTN:
+            raise RuntimeError(
+                f"Flash Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+            )
+
+    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+        if not _CAN_USE_FLASH_ATTN_3:
+            raise RuntimeError(
+                f"Flash Attention 3 backend '{backend.value}' is not usable because of a missing package or an outdated version. Please build the FA3 beta release from source."
+            )
+
+    elif backend in [
+        AttentionBackendName.SAGE,
+        AttentionBackendName.SAGE_VARLEN,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+    ]:
+        if not _CAN_USE_SAGE_ATTN:
+            raise RuntimeError(
+                f"Sage Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.FLEX:
+        if not _CAN_USE_FLEX_ATTN:
+            raise RuntimeError(
+                f"Flex Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `torch>=2.5.0`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_NPU:
+        if not _CAN_USE_NPU_ATTN:
+            raise RuntimeError(
+                f"NPU Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `torch_npu`."
+            )
+
+    elif backend == AttentionBackendName._NATIVE_XLA:
+        if not _CAN_USE_XLA_ATTN:
+            raise RuntimeError(
+                f"XLA Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+            )
+
+    elif backend == AttentionBackendName.XFORMERS:
+        if not _CAN_USE_XFORMERS_ATTN:
+            raise RuntimeError(
+                f"Xformers Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+            )
+
+
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     batch_size: int,
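
Taken together, the attention_dispatch.py changes replace the old import-time logger.warning calls with module-level capability flags plus a cached _check_backend_requirements call in get_active_backend, so a missing optional dependency now surfaces as a hard RuntimeError only when the corresponding backend is actually selected. A self-contained sketch of that lazy-check pattern, with illustrative names (Backend, _CAN_USE, check_requirements) that are not part of diffusers:

import functools
from enum import Enum


class Backend(Enum):
    FLASH = "flash"
    NATIVE = "native"


# Capability probes run once at import time (hard-coded here for illustration).
_CAN_USE = {Backend.FLASH: False, Backend.NATIVE: True}


@functools.lru_cache(maxsize=16)
def check_requirements(backend: Backend) -> None:
    # Deferred check: complain only when a backend is actually requested.
    # lru_cache memoizes successful calls, so repeated dispatches skip the work
    # (exceptions are not cached, but re-raising is cheap).
    if not _CAN_USE[backend]:
        raise RuntimeError(f"Backend '{backend.value}' needs an optional dependency that is not installed.")


check_requirements(Backend.NATIVE)  # passes silently
try:
    check_requirements(Backend.FLASH)
except RuntimeError as err:
    print(err)  # raised at selection time, not at import time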

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 0 deletions
@@ -627,6 +627,8 @@ def set_attention_backend(self, backend: str) -> None:
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
 
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
         backend = backend.lower()
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        if backend not in available_backends:
@@ -651,6 +653,8 @@ def reset_attention_backend(self) -> None:
         from .attention import AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
 
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
         attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
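
For context on the two methods touched here, a hypothetical usage sketch follows; the model class and checkpoint id are placeholders (any diffusers model built on this mixin exposes the methods), and the backend string is assumed to mirror an AttentionBackendName value:

# Placeholder model and checkpoint, for illustration only; the calls of interest are
# set_attention_backend / reset_attention_backend from this commit.
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
)

model.set_attention_backend("flash")  # now logs the experimental-feature warning
# Later forward passes dispatch through flash-attn; if `flash-attn>=2.6.3` is missing,
# backend resolution raises a RuntimeError instead of relying on an import-time warning.
model.reset_attention_backend()  # switches back to the default backend (also logs the warning)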

0 commit comments
