|
30 | 30 | is_flash_attn_3_available, |
31 | 31 | is_flash_attn_available, |
32 | 32 | is_flash_attn_version, |
| 33 | + is_aiter_available, |
| 34 | + is_aiter_version, |
33 | 35 | is_kernels_available, |
34 | 36 | is_sageattention_available, |
35 | 37 | is_sageattention_version, |
|
47 | 49 | from ._modeling_parallel import ParallelConfig |
48 | 50 |
|
49 | 51 | _REQUIRED_FLASH_VERSION = "2.6.3" |
| 52 | +_REQUIRED_AITER_VERSION = "0.1.5" |
50 | 53 | _REQUIRED_SAGE_VERSION = "2.1.1" |
51 | 54 | _REQUIRED_FLEX_VERSION = "2.5.0" |
52 | 55 | _REQUIRED_XLA_VERSION = "2.2" |
53 | 56 | _REQUIRED_XFORMERS_VERSION = "0.0.29" |
54 | 57 |
|
55 | 58 | _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) |
56 | 59 | _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() |
| 60 | +_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION) |
57 | 61 | _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) |
58 | 62 | _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) |
59 | 63 | _CAN_USE_NPU_ATTN = is_torch_npu_available() |
|
78 | 82 | flash_attn_3_func = None |
79 | 83 | flash_attn_3_varlen_func = None |
80 | 84 |
|
| 85 | + |
| 86 | +if _CAN_USE_AITER_ATTN: |
| 87 | + from aiter import flash_attn_func as aiter_flash_attn_func |
| 88 | +else: |
| 89 | + aiter_flash_attn_func = None |
| 90 | + |
81 | 91 | if DIFFUSERS_ENABLE_HUB_KERNELS: |
82 | 92 | if not is_kernels_available(): |
83 | 93 | raise ImportError( |
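
Side note for reviewers: the new gate can be reproduced outside diffusers with only the standard library plus `packaging`. This is a minimal sketch, assuming the installed distribution is actually named `aiter`; it mirrors the `_CAN_USE_AITER_ATTN` check above rather than relying on the new `is_aiter_available`/`is_aiter_version` helpers.

```python
# Minimal sketch: reproduce the _CAN_USE_AITER_ATTN gate without diffusers' helpers.
# Assumes the installed distribution is named "aiter".
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

_REQUIRED_AITER_VERSION = "0.1.5"  # same floor as in this diff


def can_use_aiter_attn() -> bool:
    try:
        return Version(version("aiter")) >= Version(_REQUIRED_AITER_VERSION)
    except PackageNotFoundError:
        return False


print(can_use_aiter_attn())
```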
@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum): |
178 | 188 | _FLASH_3_HUB = "_flash_3_hub" |
179 | 189 | # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet. |
180 | 190 |
|
| 191 | + # `aiter` |
| 192 | + AITER = "aiter" |
| 193 | + |
181 | 194 | # PyTorch native |
182 | 195 | FLEX = "flex" |
183 | 196 | NATIVE = "native" |
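
Usage-wise, nothing beyond the enum value should be needed: assuming the `"aiter"` name flows through the same selection paths as the other backends (the `set_attention_backend` helper on models and the `attention_backend` context manager), picking it up would look roughly like the sketch below. The model ID is a placeholder, and aiter targets ROCm, where PyTorch still exposes the device as `"cuda"`.

```python
# Sketch only: selecting the new "aiter" backend by name. Assumes it is wired
# into the same entry points as the existing backends; the model ID is a placeholder.
import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import attention_backend

pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.bfloat16).to("cuda")

# Persistent selection on a single model ...
pipe.transformer.set_attention_backend("aiter")

# ... or scoped selection for a single call.
with attention_backend("aiter"):
    image = pipe("an astronaut riding a horse").images[0]
```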
@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None |
414 | 427 | f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." |
415 | 428 | ) |
416 | 429 |
|
| 430 | + elif backend == AttentionBackendName.AITER: |
| 431 | + if not _CAN_USE_AITER_ATTN: |
| 432 | + raise RuntimeError( |
| 433 | + f"Aiter Attention backend '{backend.value}' is not usable because the `aiter` package isn't available or the installed version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
| 434 | + ) |
| 435 | + |
417 | 436 | elif backend in [ |
418 | 437 | AttentionBackendName.SAGE, |
419 | 438 | AttentionBackendName.SAGE_VARLEN, |
@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3( |
1397 | 1416 | return (out, lse) if return_lse else out |
1398 | 1417 |
|
1399 | 1418 |
|
| 1419 | +@_AttentionBackendRegistry.register( |
| 1420 | + AttentionBackendName.AITER, |
| 1421 | + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| 1422 | +) |
| 1423 | +def _aiter_flash_attention( |
| 1424 | + query: torch.Tensor, |
| 1425 | + key: torch.Tensor, |
| 1426 | + value: torch.Tensor, |
| 1427 | + dropout_p: float = 0.0, |
| 1428 | + is_causal: bool = False, |
| 1429 | + scale: Optional[float] = None, |
| 1430 | + return_lse: bool = False, |
| 1431 | + _parallel_config: Optional["ParallelConfig"] = None, |
| 1432 | +) -> torch.Tensor: |
| 1433 | + if not return_lse and torch.is_grad_enabled(): |
| 1434 | + # aiter asserts return_lse=True when gradients are enabled, so request the LSE here and drop it.
| 1435 | + out, lse, *_ = aiter_flash_attn_func( |
| 1436 | + q=query, |
| 1437 | + k=key, |
| 1438 | + v=value, |
| 1439 | + dropout_p=dropout_p, |
| 1440 | + softmax_scale=scale, |
| 1441 | + causal=is_causal, |
| 1442 | + return_lse=True, |
| 1443 | + ) |
| 1444 | + else: |
| 1445 | + out = aiter_flash_attn_func( |
| 1446 | + q=query, |
| 1447 | + k=key, |
| 1448 | + v=value, |
| 1449 | + dropout_p=dropout_p, |
| 1450 | + softmax_scale=scale, |
| 1451 | + causal=is_causal, |
| 1452 | + return_lse=return_lse, |
| 1453 | + ) |
| 1454 | + if return_lse: |
| 1455 | + out, lse, *_ = out |
| 1456 | + |
| 1457 | + return (out, lse) if return_lse else out |
| 1458 | + |
| 1459 | + |
1400 | 1460 | @_AttentionBackendRegistry.register( |
1401 | 1461 | AttentionBackendName.FLEX, |
1402 | 1462 | constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], |
|
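For a quick manual smoke test of the new wrapper on a ROCm machine with `aiter>=0.1.5` installed, something like the sketch below should do. It assumes this diff targets `diffusers.models.attention_dispatch` (as the registry names suggest) and that the `(batch, seq_len, heads, head_dim)` layout matches the other flash-attention backends in this module.

```python
# Smoke-test sketch (assumptions: ROCm GPU, aiter>=0.1.5, and the
# (batch, seq_len, heads, head_dim) layout used by the flash-attn backends here).
import torch
from diffusers.models.attention_dispatch import _aiter_flash_attention

q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = _aiter_flash_attention(q, k, v, is_causal=False)
out2, lse = _aiter_flash_attention(q, k, v, return_lse=True)
print(out.shape, out2.shape, lse.shape)
```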