2 changes: 2 additions & 0 deletions docs/source/en/optimization/attention_backends.md
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
| attention family | main feature |
|---|---|
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
Member:
Not related to this PR.

Do you think it might be possible to package the aiter kernels with `kernels`? If so, we could then also support them through the kernel hub, like we do for FA3 and the others (FA2 and SAGE).

Cc: @danieldk

Contributor Author:
That's a great project and would also make for a good follow-up, though perhaps best handled via a separate issue/PR? If I understand correctly, the kernel would first need to make it into `kernels` before it can be integrated into diffusers.

Member:
100% not related.

| SageAttention | quantizes attention to int8 |
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
| xFormers | memory-efficient attention with support for various attention kernels |
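As a point of reference for the kernel-hub route discussed in the thread above, here is a minimal sketch of how the existing FA3 hub backend is enabled (assumes `pip install kernels`; the checkpoint is only illustrative, and an aiter equivalent would first need a kernel published on the hub):

```py
import os

# Assumption: the flag has to be set before `diffusers` is imported so the
# module-level constant picks it up; truthy values like "1", "true", "yes" work.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# "_flash_3_hub" is the hub-backed FA3 backend listed in the table below;
# an aiter equivalent would follow the same pattern once a kernel exists.
pipe.transformer.set_attention_backend("_flash_3_hub")
```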
@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
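Putting the new row to use, a short sketch of selecting the backend by the string from this table (assumes a ROCm build of PyTorch, `aiter>=0.1.5`, and the `set_attention_backend` helper used elsewhere in this guide; the checkpoint is only illustrative):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")  # ROCm builds of PyTorch also expose the "cuda" device string

# "aiter" is the backend identifier added in this PR.
pipe.transformer.set_attention_backend("aiter")
image = pipe("a photo of an astronaut riding a horse").images[0]
```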
60 changes: 60 additions & 0 deletions src/diffusers/models/attention_dispatch.py
@@ -27,6 +27,8 @@

from ..utils import (
    get_logger,
    is_aiter_available,
    is_aiter_version,
    is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
@@ -47,13 +49,15 @@
from ._modeling_parallel import ParallelConfig

_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -78,6 +82,12 @@
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None


if _CAN_USE_AITER_ATTN:
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    aiter_flash_attn_func = None

if DIFFUSERS_ENABLE_HUB_KERNELS:
    if not is_kernels_available():
        raise ImportError(
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
    _FLASH_3_HUB = "_flash_3_hub"
    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

    # `aiter`
    AITER = "aiter"

    # PyTorch native
    FLEX = "flex"
    NATIVE = "native"
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)

elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.AITER,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _aiter_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if not return_lse and torch.is_grad_enabled():
        # aiter requires return_lse=True by assertion when gradients are enabled,
        # so request the LSE here and simply drop it below.
        out, lse, *_ = aiter_flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
            return_lse=True,
        )
    else:
        out = aiter_flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
            return_lse=return_lse,
        )
        if return_lse:
            # With return_lse=True, aiter returns a tuple; unpack the output and LSE.
            out, lse, *_ = out

    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLEX,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
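Outside the diff itself, a quick sketch of the contract the new wrapper exposes to the dispatcher; the `(batch, seq_len, num_heads, head_dim)` layout is an assumption based on the flash-attention convention, and a ROCm GPU with `aiter` installed is required:

```py
import torch

from diffusers.models.attention_dispatch import _aiter_flash_attention

# bf16 inputs on the GPU (ROCm builds of PyTorch expose it as "cuda").
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)

out = _aiter_flash_attention(q, k, v)                        # attention output only
out, lse = _aiter_flash_attention(q, k, v, return_lse=True)  # output plus log-sum-exp
```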
2 changes: 2 additions & 0 deletions src/diffusers/utils/__init__.py
@@ -64,6 +64,8 @@
    get_objects_from_module,
    is_accelerate_available,
    is_accelerate_version,
    is_aiter_available,
    is_aiter_version,
    is_better_profanity_available,
    is_bitsandbytes_available,
    is_bitsandbytes_version,
21 changes: 21 additions & 0 deletions src/diffusers/utils/import_utils.py
@@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)

@@ -406,6 +407,10 @@ def is_flash_attn_3_available():
    return _flash_attn_3_available


def is_aiter_available():
    return _aiter_available


def is_kornia_available():
    return _kornia_available

@@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
    return compare_versions(parse(_flash_attn_version), operation, version)


@cache
def is_aiter_version(operation: str, version: str):
    """
    Compares the current aiter version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _aiter_available:
        return False
    return compare_versions(parse(_aiter_version), operation, version)


def get_objects_from_module(module):
    """
    Returns a dict of object names and values in a module, while skipping private/internal objects
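A brief usage sketch (not part of the PR) showing how these helpers are intended to guard optional imports, mirroring the `_CAN_USE_AITER_ATTN` check added in `attention_dispatch.py`:

```py
from diffusers.utils import is_aiter_available, is_aiter_version

# Both checks return False when `aiter` is absent, so this guard never raises
# ImportError on machines without ROCm/aiter.
if is_aiter_available() and is_aiter_version(">=", "0.1.5"):
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    aiter_flash_attn_func = None
```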
13 changes: 13 additions & 0 deletions tests/others/test_attention_backends.py
@@ -14,6 +14,10 @@

Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).

Tests for the aiter backend were conducted, and the expected slices collected, on an MI355X with a torch
nightly build from 2025-09-25 (ad2f7315ca66b42497047bb7951f696b50f1e81b) and aiter 0.1.5.post4.dev20+ga25e55e79.
"""

import os
@@ -44,6 +48,10 @@
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
)
]

COMPILE_CASES = [
@@ -63,6 +71,11 @@
        torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
        True,
    ),
    (
        "aiter",
        torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
        True,
    ),
]
# fmt: on
