diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 141a7fee858b..c00ec7dd6e41 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -38,18 +38,29 @@ from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS -logger = get_logger(__name__) # pylint: disable=invalid-name - - -if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): +_REQUIRED_FLASH_VERSION = "2.6.3" +_REQUIRED_SAGE_VERSION = "2.1.1" +_REQUIRED_FLEX_VERSION = "2.5.0" +_REQUIRED_XLA_VERSION = "2.2" +_REQUIRED_XFORMERS_VERSION = "0.0.29" + +_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) +_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() +_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) +_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) +_CAN_USE_NPU_ATTN = is_torch_npu_available() +_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) +_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) + + +if _CAN_USE_FLASH_ATTN: from flash_attn import flash_attn_func, flash_attn_varlen_func else: - logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") flash_attn_func = None flash_attn_varlen_func = None -if is_flash_attn_3_available(): +if _CAN_USE_FLASH_ATTN_3: from flash_attn_interface import flash_attn_func as flash_attn_3_func from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func else: @@ -57,7 +68,7 @@ flash_attn_3_varlen_func = None -if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): +if _CAN_USE_SAGE_ATTN: from sageattention import ( sageattn, sageattn_qk_int8_pv_fp8_cuda, @@ -67,9 +78,6 @@ sageattn_varlen, ) else: - logger.warning( - "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." - ) sageattn = None sageattn_qk_int8_pv_fp16_cuda = None sageattn_qk_int8_pv_fp16_triton = None @@ -78,39 +86,39 @@ sageattn_varlen = None -if is_torch_version(">=", "2.5.0"): +if _CAN_USE_FLEX_ATTN: # We cannot import the flex_attention function from the package directly because it is expected (from the # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the # compiled function. import torch.nn.attention.flex_attention as flex_attention -if is_torch_npu_available(): +if _CAN_USE_NPU_ATTN: from torch_npu import npu_fusion_attention else: npu_fusion_attention = None -if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): +if _CAN_USE_XLA_ATTN: from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention else: xla_flash_attention = None -if is_xformers_available() and is_xformers_version(">=", "0.0.29"): +if _CAN_USE_XFORMERS_ATTN: import xformers.ops as xops else: - logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.") xops = None +logger = get_logger(__name__) # pylint: disable=invalid-name + # TODO(aryan): Add support for the following: # - Sage Attention++ # - block sparse, radial and other attention methods # - CP with sage attention, flex, xformers, other missing backends # - Add support for normal and CP training with backends that don't support it yet - _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] @@ -179,13 +187,16 @@ def list_backends(cls): @contextlib.contextmanager -def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): +def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): """ Context manager to set the active attention backend. """ if backend not in _AttentionBackendRegistry._backends: raise ValueError(f"Backend {backend} is not registered.") + backend = AttentionBackendName(backend) + _check_attention_backend_requirements(backend) + old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry._active_backend = backend @@ -226,9 +237,10 @@ def dispatch_attention_fn( "dropout_p": dropout_p, "is_causal": is_causal, "scale": scale, - "enable_gqa": enable_gqa, **attention_kwargs, } + if is_torch_version(">=", "2.5.0"): + kwargs["enable_gqa"] = enable_gqa if _AttentionBackendRegistry._checks_enabled: removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) @@ -305,6 +317,57 @@ def _check_shape( # ===== Helper functions ===== +def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: + if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: + if not _CAN_USE_FLASH_ATTN: + raise RuntimeError( + f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." + ) + + elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: + if not _CAN_USE_FLASH_ATTN_3: + raise RuntimeError( + f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." + ) + + elif backend in [ + AttentionBackendName.SAGE, + AttentionBackendName.SAGE_VARLEN, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + ]: + if not _CAN_USE_SAGE_ATTN: + raise RuntimeError( + f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." + ) + + elif backend == AttentionBackendName.FLEX: + if not _CAN_USE_FLEX_ATTN: + raise RuntimeError( + f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`." + ) + + elif backend == AttentionBackendName._NATIVE_NPU: + if not _CAN_USE_NPU_ATTN: + raise RuntimeError( + f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." + ) + + elif backend == AttentionBackendName._NATIVE_XLA: + if not _CAN_USE_XLA_ATTN: + raise RuntimeError( + f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`." + ) + + elif backend == AttentionBackendName.XFORMERS: + if not _CAN_USE_XFORMERS_ATTN: + raise RuntimeError( + f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." + ) + + @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( batch_size: int, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fb01e7e01a1e..4941b6d2a7b5 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -622,19 +622,21 @@ def set_attention_backend(self, backend: str) -> None: attention as backend. """ from .attention import AttentionModuleMixin - from .attention_dispatch import AttentionBackendName + from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements # TODO: the following will not be required when everything is refactored to AttentionModuleMixin from .attention_processor import Attention, MochiAttention + logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + backend = backend.lower() available_backends = {x.value for x in AttentionBackendName.__members__.values()} if backend not in available_backends: raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) - backend = AttentionBackendName(backend) - attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + _check_attention_backend_requirements(backend) + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): if not isinstance(module, attention_classes): continue @@ -651,6 +653,8 @@ def reset_attention_backend(self) -> None: from .attention import AttentionModuleMixin from .attention_processor import Attention, MochiAttention + logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): if not isinstance(module, attention_classes):