diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
index 8be2c0603009..edfdcc38b50b 100644
--- a/docs/source/en/optimization/attention_backends.md
+++ b/docs/source/en/optimization/attention_backends.md
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
 | attention family | main feature |
 |---|---|
 | FlashAttention | minimizes memory reads/writes through tiling and recomputation |
+| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
 | SageAttention | quantizes attention to int8 |
 | PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
 | xFormers | memory-efficient attention with support for various attention kernels |
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index e1694910997a..ab0d7102ee83 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -27,6 +27,8 @@
 from ..utils import (
     get_logger,
+    is_aiter_available,
+    is_aiter_version,
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
@@ -47,6 +49,7 @@
     from ._modeling_parallel import ParallelConfig
 
 _REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_AITER_VERSION = "0.1.5"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
 _REQUIRED_XLA_VERSION = "2.2"
@@ -54,6 +57,7 @@
 
 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
 _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
 _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
 _CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -78,6 +82,12 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
+
+if _CAN_USE_AITER_ATTN:
+    from aiter import flash_attn_func as aiter_flash_attn_func
+else:
+    aiter_flash_attn_func = None
+
 if DIFFUSERS_ENABLE_HUB_KERNELS:
     if not is_kernels_available():
         raise ImportError(
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
     _FLASH_3_HUB = "_flash_3_hub"
     # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
 
+    # `aiter`
+    AITER = "aiter"
+
     # PyTorch native
     FLEX = "flex"
     NATIVE = "native"
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
 
+    elif backend == AttentionBackendName.AITER:
+        if not _CAN_USE_AITER_ATTN:
+            raise RuntimeError(
+                f"Aiter attention backend '{backend.value}' is not usable because the `aiter` package is missing or its version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.AITER,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _aiter_flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if not return_lse and torch.is_grad_enabled():
+        # aiter asserts return_lse=True when gradients are enabled, so request the LSE and discard it below.
+        out, lse, *_ = aiter_flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_lse=True,
+        )
+    else:
+        out = aiter_flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLEX,
     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index d8e1a5540100..cf77aaee8205 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -64,6 +64,8 @@
     get_objects_from_module,
     is_accelerate_available,
     is_accelerate_version,
+    is_aiter_available,
+    is_aiter_version,
     is_better_profanity_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 97065267b004..adf8ed8b0694 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
 
@@ -406,6 +407,10 @@
     return _flash_attn_3_available
 
 
+def is_aiter_available():
+    return _aiter_available
+
+
 def is_kornia_available():
     return _kornia_available
 
@@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
     return compare_versions(parse(_flash_attn_version), operation, version)
 
 
+@cache
+def is_aiter_version(operation: str, version: str):
+    """
+    Compares the current aiter version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _aiter_available:
+        return False
+    return compare_versions(parse(_aiter_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
index 42cdcd56f74a..2e5a2fc82bb6 100644
--- a/tests/others/test_attention_backends.py
+++ b/tests/others/test_attention_backends.py
@@ -14,6 +14,10 @@
 Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
 "native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+
+Tests for the aiter backend were conducted, and its slices collected, on an MI355X
+with a torch nightly from 2025-09-25 (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
+aiter 0.1.5.post4.dev20+ga25e55e79.
 """
 
 import os
@@ -44,6 +48,10 @@
         "_native_cudnn",
         torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
     ),
+    (
+        "aiter",
+        torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
+    ),
 ]
 
 COMPILE_CASES = [
@@ -63,6 +71,11 @@
         torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
         True,
     ),
+    (
+        "aiter",
+        torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
+        True,
+    ),
 ]
 # fmt: on
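For reviewers who want to exercise the new backend, the snippet below is a minimal usage sketch rather than part of the diff. It assumes a ROCm build of PyTorch (AMD accelerators surface under the `"cuda"` device type, which is why the `_check_device_cuda` constraint applies), `aiter>=0.1.5`, and the `set_attention_backend`/`reset_attention_backend` helpers already documented in `attention_backends.md`; the checkpoint name is only illustrative.

```python
# Minimal sketch: route a pipeline's transformer through the new "aiter" backend.
# Assumes a ROCm PyTorch build, aiter >= 0.1.5, and an illustrative checkpoint.
import torch

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")  # ROCm devices are exposed through the "cuda" device type

# Dispatch the transformer's attention layers to the aiter FlashAttention kernel.
pipeline.transformer.set_attention_backend("aiter")

image = pipeline("A cat holding a sign that reads 'Hello'").images[0]
image.save("output.png")

# Fall back to the default attention backend afterwards.
pipeline.transformer.reset_attention_backend()
```

Because `aiter` asserts `return_lse=True` whenever gradients are enabled, the dispatcher added in this diff always requests the LSE under `torch.is_grad_enabled()` and drops it for callers that did not ask for it; inference-only paths are unaffected.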