
Commit 633a4dd

add aiter attention backend
1 parent 7536f64 commit 633a4dd

File tree

5 files changed: +95 -0 lines changed

docs/source/en/optimization/attention_backends.md
src/diffusers/models/attention_dispatch.py
src/diffusers/utils/__init__.py
src/diffusers/utils/import_utils.py
tests/others/test_attention_backends.py

docs/source/en/optimization/attention_backends.md

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
 | attention family | main feature |
 |---|---|
 | FlashAttention | minimizes memory reads/writes through tiling and recomputation |
+| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
 | SageAttention | quantizes attention to int8 |
 | PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
 | xFormers | memory-efficient attention with support for various attention kernels |

@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
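For orientation, the new row corresponds to a backend string that is selected like any other entry in this table. Below is a minimal sketch, assuming a ROCm GPU with `aiter>=0.1.5` installed and a diffusers build that exposes `set_attention_backend` and the `attention_backend` context manager described in this doc; the Flux pipeline and prompt are only illustrative:

```python
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import attention_backend

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")  # ROCm devices are also addressed as "cuda" in PyTorch

# Route the transformer's attention through the aiter backend globally...
pipe.transformer.set_attention_backend("aiter")
image = pipe("a photo of a cat").images[0]

# ...or scope it to a single call with the context manager.
with attention_backend("aiter"):
    image = pipe("a photo of a cat").images[0]
```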

src/diffusers/models/attention_dispatch.py

Lines changed: 60 additions & 0 deletions

@@ -30,6 +30,8 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_aiter_available,
+    is_aiter_version,
     is_kernels_available,
     is_sageattention_available,
     is_sageattention_version,

@@ -47,13 +49,15 @@
     from ._modeling_parallel import ParallelConfig

 _REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_AITER_VERSION = "0.1.5"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
 _REQUIRED_XLA_VERSION = "2.2"
 _REQUIRED_XFORMERS_VERSION = "0.0.29"

 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
 _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
 _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
 _CAN_USE_NPU_ATTN = is_torch_npu_available()

@@ -78,6 +82,12 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

+
+if _CAN_USE_AITER_ATTN:
+    from aiter import flash_attn_func as aiter_flash_attn_func
+else:
+    aiter_flash_attn_func = None
+
 if DIFFUSERS_ENABLE_HUB_KERNELS:
     if not is_kernels_available():
         raise ImportError(

@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
     _FLASH_3_HUB = "_flash_3_hub"
     # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

+    # `aiter`
+    AITER = "aiter"
+
     # PyTorch native
     FLEX = "flex"
     NATIVE = "native"

@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )

+    elif backend == AttentionBackendName.AITER:
+        if not _CAN_USE_AITER_ATTN:
+            raise RuntimeError(
+                f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,

@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.AITER,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _aiter_flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if not return_lse and torch.is_grad_enabled():
+        # aiter requires return_lse=True by assertion when gradients are enabled.
+        out, lse, *_ = aiter_flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_lse=True,
+        )
+    else:
+        out = aiter_flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLEX,
     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
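The branching in `_aiter_flash_attention` exists because of the behavior noted in its inline comment: when autograd is enabled, aiter's `flash_attn_func` must be asked for the log-sum-exp, so the wrapper always requests it on that path and simply drops it if the caller did not want it. A minimal sketch of the two call paths, assuming a ROCm GPU, bf16 inputs, and the (batch, seq_len, num_heads, head_dim) layout used by FlashAttention-style kernels (the shapes are made up for illustration):

```python
import torch
from aiter import flash_attn_func  # imported as aiter_flash_attn_func in the diff above

# Illustrative shapes only: (batch, seq_len, num_heads, head_dim).
q = torch.randn(2, 1024, 24, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inference path: mirrors the else-branch, where return_lse is passed through as-is.
with torch.no_grad():
    out = flash_attn_func(
        q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=False, return_lse=False
    )

# Training path: mirrors the first branch, which always requests the LSE because
# aiter asserts on it when gradients are enabled, then discards it if unused.
q.requires_grad_(True)
out, lse, *_ = flash_attn_func(
    q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=False, return_lse=True
)
```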

src/diffusers/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -72,6 +72,8 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_aiter_available,
+    is_aiter_version,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,

src/diffusers/utils/import_utils.py

Lines changed: 18 additions & 0 deletions

@@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)

@@ -405,6 +406,8 @@ def is_flash_attn_available():
 def is_flash_attn_3_available():
     return _flash_attn_3_available

+def is_aiter_available():
+    return _aiter_available

 def is_kornia_available():
     return _kornia_available

@@ -910,6 +913,21 @@ def is_flash_attn_version(operation: str, version: str):
         return False
     return compare_versions(parse(_flash_attn_version), operation, version)

+@cache
+def is_aiter_version(operation: str, version: str):
+    """
+    Compares the current aiter version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _aiter_available:
+        return False
+    return compare_versions(parse(_aiter_version), operation, version)
+

 def get_objects_from_module(module):
     """

tests/others/test_attention_backends.py

Lines changed: 13 additions & 0 deletions

@@ -14,6 +14,10 @@

 Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
 "native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+
+Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
+with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
+aiter 0.1.5.post4.dev20+ga25e55e79.
 """

 import os

@@ -44,6 +48,10 @@
         "_native_cudnn",
         torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
     ),
+    (
+        "aiter",
+        torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
+    )
 ]

 COMPILE_CASES = [

@@ -63,6 +71,11 @@
         torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
         True,
     ),
+    (
+        "aiter",
+        torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
+        True,
+    )
 ]
 # fmt: on