30 | 30 |     is_flash_attn_3_available,
31 | 31 |     is_flash_attn_available,
32 | 32 |     is_flash_attn_version,
   | 33 | +    is_aiter_available,
   | 34 | +    is_aiter_version,
33 | 35 |     is_kernels_available,
34 | 36 |     is_sageattention_available,
35 | 37 |     is_sageattention_version,

47 | 49 |     from ._modeling_parallel import ParallelConfig
48 | 50 |
49 | 51 | _REQUIRED_FLASH_VERSION = "2.6.3"
   | 52 | +_REQUIRED_AITER_VERSION = "0.1.5"
50 | 53 | _REQUIRED_SAGE_VERSION = "2.1.1"
51 | 54 | _REQUIRED_FLEX_VERSION = "2.5.0"
52 | 55 | _REQUIRED_XLA_VERSION = "2.2"
53 | 56 | _REQUIRED_XFORMERS_VERSION = "0.0.29"
54 | 57 |
55 | 58 | _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
56 | 59 | _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
   | 60 | +_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
57 | 61 | _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
58 | 62 | _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
59 | 63 | _CAN_USE_NPU_ATTN = is_torch_npu_available()

78 | 82 |     flash_attn_3_func = None
79 | 83 |     flash_attn_3_varlen_func = None
80 | 84 |
   | 85 | +
   | 86 | +if _CAN_USE_AITER_ATTN:
   | 87 | +    from aiter import flash_attn_func as aiter_flash_attn_func
   | 88 | +else:
   | 89 | +    aiter_flash_attn_func = None
   | 90 | +
81 | 91 | if DIFFUSERS_ENABLE_HUB_KERNELS:
82 | 92 |     if not is_kernels_available():
83 | 93 |         raise ImportError(
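The new backend follows the same pattern as the existing ones: an availability/version flag computed once at import time, plus a guarded import that falls back to `None` so call sites only need to check the flag. Below is a minimal, self-contained sketch of that pattern; it uses `importlib` and `packaging` directly rather than diffusers' own `is_aiter_available`/`is_aiter_version` helpers, whose internals are not shown in this diff.

import importlib.metadata
import importlib.util

from packaging import version

_REQUIRED_AITER_VERSION = "0.1.5"


def _aiter_available() -> bool:
    # True only if the `aiter` package can be imported at all.
    return importlib.util.find_spec("aiter") is not None


def _aiter_version_ok(minimum: str) -> bool:
    try:
        installed = importlib.metadata.version("aiter")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(minimum)


_CAN_USE_AITER_ATTN = _aiter_available() and _aiter_version_ok(_REQUIRED_AITER_VERSION)

if _CAN_USE_AITER_ATTN:
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    # Sentinel so later code can rely on the flag instead of try/except at use time.
    aiter_flash_attn_func = None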
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
178 | 188 |     _FLASH_3_HUB = "_flash_3_hub"
179 | 189 |     # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
180 | 190 |
    | 191 | +    # `aiter`
    | 192 | +    AITER = "aiter"
    | 193 | +
181 | 194 |     # PyTorch native
182 | 195 |     FLEX = "flex"
183 | 196 |     NATIVE = "native"
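Because `AttentionBackendName` subclasses both `str` and `Enum`, the new member can be looked up from the plain string a user passes and compared directly against strings. A small illustration with a stand-in enum (not the real class):

from enum import Enum


class BackendName(str, Enum):
    # Stand-in for diffusers' AttentionBackendName, reduced to the new member.
    AITER = "aiter"


assert BackendName("aiter") is BackendName.AITER  # lookup by value
assert BackendName.AITER == "aiter"               # str comparison also works
assert BackendName.AITER.value == "aiter"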
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
414 | 427 |                 f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
415 | 428 |             )
416 | 429 |
    | 430 | +    elif backend == AttentionBackendName.AITER:
    | 431 | +        if not _CAN_USE_AITER_ATTN:
    | 432 | +            raise RuntimeError(
    | 433 | +                f"Aiter Attention backend '{backend.value}' is not usable because the `aiter` package isn't available or its version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
    | 434 | +            )
    | 435 | +
417 | 436 |     elif backend in [
418 | 437 |         AttentionBackendName.SAGE,
419 | 438 |         AttentionBackendName.SAGE_VARLEN,
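Once this requirement check passes, the backend is selected by its registered name like any other. A hedged usage sketch, assuming diffusers' `set_attention_backend` model API (which routes through the dispatcher this diff extends), a ROCm build of PyTorch, and `aiter` installed; the pipeline and checkpoint are placeholders:

import torch
from diffusers import FluxPipeline

# Placeholder checkpoint; any pipeline whose transformer dispatches attention
# through this registry should behave the same way.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")  # ROCm devices are also exposed as "cuda" in PyTorch

# Assumed API: switches the dispatcher to the newly registered "aiter" backend;
# the check above raises if `aiter` is missing or older than the required version.
pipe.transformer.set_attention_backend("aiter")

image = pipe("a photo of a corgi wearing sunglasses", num_inference_steps=28).images[0]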
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
1397 | 1416 |     return (out, lse) if return_lse else out
1398 | 1417 |
1399 | 1418 |
     | 1419 | +@_AttentionBackendRegistry.register(
     | 1420 | +    AttentionBackendName.AITER,
     | 1421 | +    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
     | 1422 | +)
     | 1423 | +def _aiter_flash_attention(
     | 1424 | +    query: torch.Tensor,
     | 1425 | +    key: torch.Tensor,
     | 1426 | +    value: torch.Tensor,
     | 1427 | +    dropout_p: float = 0.0,
     | 1428 | +    is_causal: bool = False,
     | 1429 | +    scale: Optional[float] = None,
     | 1430 | +    return_lse: bool = False,
     | 1431 | +    _parallel_config: Optional["ParallelConfig"] = None,
     | 1432 | +) -> torch.Tensor:
     | 1433 | +    if not return_lse and torch.is_grad_enabled():
     | 1434 | +        # aiter asserts return_lse=True when gradients are enabled, so request it and drop the LSE below.
     | 1435 | +        out, lse, *_ = aiter_flash_attn_func(
     | 1436 | +            q=query,
     | 1437 | +            k=key,
     | 1438 | +            v=value,
     | 1439 | +            dropout_p=dropout_p,
     | 1440 | +            softmax_scale=scale,
     | 1441 | +            causal=is_causal,
     | 1442 | +            return_lse=True,
     | 1443 | +        )
     | 1444 | +    else:
     | 1445 | +        out = aiter_flash_attn_func(
     | 1446 | +            q=query,
     | 1447 | +            k=key,
     | 1448 | +            v=value,
     | 1449 | +            dropout_p=dropout_p,
     | 1450 | +            softmax_scale=scale,
     | 1451 | +            causal=is_causal,
     | 1452 | +            return_lse=return_lse,
     | 1453 | +        )
     | 1454 | +        if return_lse:
     | 1455 | +            out, lse, *_ = out
     | 1456 | +
     | 1457 | +    return (out, lse) if return_lse else out
     | 1458 | +
     | 1459 | +
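The wrapper's only real logic is the LSE bookkeeping: aiter needs `return_lse=True` whenever autograd is active, while the dispatcher's callers may or may not want the LSE back. That branching can be exercised without ROCm or aiter by stubbing the kernel; a minimal, self-contained sketch (the stub shares only the keyword arguments used above, and the tensor layouts are assumptions, not aiter's documented contract):

import torch


def fake_aiter_flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_lse=False):
    # Stand-in kernel: returns dummy tensors with the shapes the wrapper expects.
    out = torch.zeros_like(q)
    if return_lse:
        lse = torch.zeros(q.shape[0], q.shape[2], q.shape[1])  # (batch, heads, seq_len), assumed layout
        return out, lse
    return out


def aiter_style_wrapper(query, key, value, return_lse=False, attn_func=fake_aiter_flash_attn_func):
    # Mirrors the branching in _aiter_flash_attention above.
    if not return_lse and torch.is_grad_enabled():
        out, lse, *_ = attn_func(q=query, k=key, v=value, return_lse=True)
    else:
        out = attn_func(q=query, k=key, v=value, return_lse=return_lse)
        if return_lse:
            out, lse, *_ = out
    return (out, lse) if return_lse else out


q = k = v = torch.randn(1, 16, 8, 64, requires_grad=True)  # (batch, seq_len, heads, head_dim), assumed layout
assert aiter_style_wrapper(q, k, v).shape == q.shape       # grad enabled: LSE requested, then discarded
with torch.no_grad():
    assert aiter_style_wrapper(q, k, v).shape == q.shape   # no grad: plain call
out, lse = aiter_style_wrapper(q, k, v, return_lse=True)   # caller explicitly wants the LSE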
1400 | 1460 | @_AttentionBackendRegistry.register(
1401 | 1461 |     AttentionBackendName.FLEX,
1402 | 1462 |     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],