| 38 | 38 | from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS | 
| 39 | 39 | 
 | 
| 40 | 40 | 
 | 
| 41 |  | -logger = get_logger(__name__)  # pylint: disable=invalid-name | 
| 42 |  | - | 
| 43 |  | - | 
| 44 |  | -if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): | 
|  | 41 | +_REQUIRED_FLASH_VERSION = "2.6.3" | 
|  | 42 | +_REQUIRED_SAGE_VERSION = "2.1.1" | 
|  | 43 | +_REQUIRED_FLEX_VERSION = "2.5.0" | 
|  | 44 | +_REQUIRED_XLA_VERSION = "2.2" | 
|  | 45 | +_REQUIRED_XFORMERS_VERSION = "0.0.29" | 
|  | 46 | + | 
|  | 47 | +_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) | 
|  | 48 | +_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() | 
|  | 49 | +_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) | 
|  | 50 | +_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) | 
|  | 51 | +_CAN_USE_NPU_ATTN = is_torch_npu_available() | 
|  | 52 | +_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) | 
|  | 53 | +_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) | 
|  | 54 | + | 
|  | 55 | + | 
|  | 56 | +if _CAN_USE_FLASH_ATTN: | 
| 45 | 57 |     from flash_attn import flash_attn_func, flash_attn_varlen_func | 
| 46 | 58 | else: | 
| 47 |  | -    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") | 
| 48 | 59 |     flash_attn_func = None | 
| 49 | 60 |     flash_attn_varlen_func = None | 
| 50 | 61 | 
 | 
| 51 | 62 | 
 | 
| 52 |  | -if is_flash_attn_3_available(): | 
|  | 63 | +if _CAN_USE_FLASH_ATTN_3: | 
| 53 | 64 |     from flash_attn_interface import flash_attn_func as flash_attn_3_func | 
| 54 | 65 |     from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func | 
| 55 | 66 | else: | 
| 56 | 67 |     flash_attn_3_func = None | 
| 57 | 68 |     flash_attn_3_varlen_func = None | 
| 58 | 69 | 
 | 
| 59 | 70 | 
 | 
| 60 |  | -if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): | 
|  | 71 | +if _CAN_USE_SAGE_ATTN: | 
| 61 | 72 |     from sageattention import ( | 
| 62 | 73 |         sageattn, | 
| 63 | 74 |         sageattn_qk_int8_pv_fp8_cuda, | 
|  | 
| 67 | 78 |         sageattn_varlen, | 
| 68 | 79 |     ) | 
| 69 | 80 | else: | 
| 70 |  | -    logger.warning( | 
| 71 |  | -        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." | 
| 72 |  | -    ) | 
| 73 | 81 |     sageattn = None | 
| 74 | 82 |     sageattn_qk_int8_pv_fp16_cuda = None | 
| 75 | 83 |     sageattn_qk_int8_pv_fp16_triton = None | 
|  | 
| 78 | 86 |     sageattn_varlen = None | 
| 79 | 87 | 
 | 
| 80 | 88 | 
 | 
| 81 |  | -if is_torch_version(">=", "2.5.0"): | 
|  | 89 | +if _CAN_USE_FLEX_ATTN: | 
| 82 | 90 |     # We cannot import the flex_attention function from the package directly because it is expected (from the | 
| 83 | 91 |     # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the | 
| 84 | 92 |     # compiled function. | 
| 85 | 93 |     import torch.nn.attention.flex_attention as flex_attention | 
| 86 | 94 | 
 | 
| 87 | 95 | 
 | 
| 88 |  | -if is_torch_npu_available(): | 
|  | 96 | +if _CAN_USE_NPU_ATTN: | 
| 89 | 97 |     from torch_npu import npu_fusion_attention | 
| 90 | 98 | else: | 
| 91 | 99 |     npu_fusion_attention = None | 
| 92 | 100 | 
 | 
| 93 | 101 | 
 | 
| 94 |  | -if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): | 
|  | 102 | +if _CAN_USE_XLA_ATTN: | 
| 95 | 103 |     from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention | 
| 96 | 104 | else: | 
| 97 | 105 |     xla_flash_attention = None | 
| 98 | 106 | 
 | 
| 99 | 107 | 
 | 
| 100 |  | -if is_xformers_available() and is_xformers_version(">=", "0.0.29"): | 
|  | 108 | +if _CAN_USE_XFORMERS_ATTN: | 
| 101 | 109 |     import xformers.ops as xops | 
| 102 | 110 | else: | 
| 103 |  | -    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.") | 
| 104 | 111 |     xops = None | 
| 105 | 112 | 
 | 
| 106 | 113 | 
 | 
|  | 114 | +logger = get_logger(__name__)  # pylint: disable=invalid-name | 
|  | 115 | + | 
| 107 | 116 | # TODO(aryan): Add support for the following: | 
| 108 | 117 | # - Sage Attention++ | 
| 109 | 118 | # - block sparse, radial and other attention methods | 
| 110 | 119 | # - CP with sage attention, flex, xformers, other missing backends | 
| 111 | 120 | # - Add support for normal and CP training with backends that don't support it yet | 
| 112 | 121 | 
 | 
| 113 |  | - | 
| 114 | 122 | _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] | 
| 115 | 123 | _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] | 
| 116 | 124 | _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] | 
| @@ -179,13 +187,16 @@ def list_backends(cls): | 
| 179 | 187 | 
 | 
| 180 | 188 | 
 | 
| 181 | 189 | @contextlib.contextmanager | 
| 182 |  | -def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): | 
|  | 190 | +def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): | 
| 183 | 191 |     """ | 
| 184 | 192 |     Context manager to set the active attention backend. | 
| 185 | 193 |     """ | 
| 186 | 194 |     if backend not in _AttentionBackendRegistry._backends: | 
| 187 | 195 |         raise ValueError(f"Backend {backend} is not registered.") | 
| 188 | 196 | 
 | 
|  | 197 | +    backend = AttentionBackendName(backend) | 
|  | 198 | +    _check_attention_backend_requirements(backend) | 
|  | 199 | + | 
| 189 | 200 |     old_backend = _AttentionBackendRegistry._active_backend | 
| 190 | 201 |     _AttentionBackendRegistry._active_backend = backend | 
| 191 | 202 | 
 | 
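Accepting `Union[str, AttentionBackendName]` plus the eager requirement check changes how the context manager behaves at the call site. A usage sketch, assuming CUDA tensors and that `dispatch_attention_fn` takes query/key/value as its leading arguments (the shapes and the dispatch call are illustrative, not taken from this diff):

```python
import torch

# Illustrative shapes: (batch, heads, seq_len, head_dim).
query = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
value = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

try:
    # Requirement checks now run on entry, so an unusable backend raises a
    # RuntimeError here rather than failing deep inside the attention call.
    with attention_backend(AttentionBackendName.FLASH):
        out = dispatch_attention_fn(query, key, value)
except RuntimeError:
    # e.g. flash-attn missing or older than the required version
    with attention_backend(AttentionBackendName.NATIVE):
        out = dispatch_attention_fn(query, key, value)
```

Plain strings are accepted as well; they are normalized through `AttentionBackendName(backend)` inside the context manager.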
| @@ -226,9 +237,10 @@ def dispatch_attention_fn( | 
| 226 | 237 |         "dropout_p": dropout_p, | 
| 227 | 238 |         "is_causal": is_causal, | 
| 228 | 239 |         "scale": scale, | 
| 229 |  | -        "enable_gqa": enable_gqa, | 
| 230 | 240 |         **attention_kwargs, | 
| 231 | 241 |     } | 
|  | 242 | +    if is_torch_version(">=", "2.5.0"): | 
|  | 243 | +        kwargs["enable_gqa"] = enable_gqa | 
| 232 | 244 | 
 | 
| 233 | 245 |     if _AttentionBackendRegistry._checks_enabled: | 
| 234 | 246 |         removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) | 
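The gate exists because `enable_gqa` was only added to `torch.nn.functional.scaled_dot_product_attention` in PyTorch 2.5, so passing it unconditionally raises a `TypeError` on older releases. A short sketch of what the flag buys for grouped-query attention, with an illustrative fallback for older torch (the version check below is simplified; the code above uses `is_torch_version`):

```python
import torch
import torch.nn.functional as F

# Grouped-query attention: 8 query heads share 2 key/value heads.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 2, 128, 64)
v = torch.randn(1, 2, 128, 64)

if tuple(int(x) for x in torch.__version__.split(".")[:2]) >= (2, 5):
    # SDPA broadcasts the kv heads across the query-head groups internally.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
else:
    # Older torch: repeat kv heads manually to match the query head count.
    out = F.scaled_dot_product_attention(
        q, k.repeat_interleave(4, dim=1), v.repeat_interleave(4, dim=1), is_causal=True
    )
```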
| @@ -305,6 +317,57 @@ def _check_shape( | 
| 305 | 317 | # ===== Helper functions ===== | 
| 306 | 318 | 
 | 
| 307 | 319 | 
 | 
|  | 320 | +def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: | 
|  | 321 | +    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: | 
|  | 322 | +        if not _CAN_USE_FLASH_ATTN: | 
|  | 323 | +            raise RuntimeError( | 
|  | 324 | +                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." | 
|  | 325 | +            ) | 
|  | 326 | + | 
|  | 327 | +    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: | 
|  | 328 | +        if not _CAN_USE_FLASH_ATTN_3: | 
|  | 329 | +            raise RuntimeError( | 
|  | 330 | +                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." | 
|  | 331 | +            ) | 
|  | 332 | + | 
|  | 333 | +    elif backend in [ | 
|  | 334 | +        AttentionBackendName.SAGE, | 
|  | 335 | +        AttentionBackendName.SAGE_VARLEN, | 
|  | 336 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, | 
|  | 337 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, | 
|  | 338 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, | 
|  | 339 | +        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, | 
|  | 340 | +    ]: | 
|  | 341 | +        if not _CAN_USE_SAGE_ATTN: | 
|  | 342 | +            raise RuntimeError( | 
|  | 343 | +                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." | 
|  | 344 | +            ) | 
|  | 345 | + | 
|  | 346 | +    elif backend == AttentionBackendName.FLEX: | 
|  | 347 | +        if not _CAN_USE_FLEX_ATTN: | 
|  | 348 | +            raise RuntimeError( | 
|  | 349 | +                f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`." | 
|  | 350 | +            ) | 
|  | 351 | + | 
|  | 352 | +    elif backend == AttentionBackendName._NATIVE_NPU: | 
|  | 353 | +        if not _CAN_USE_NPU_ATTN: | 
|  | 354 | +            raise RuntimeError( | 
|  | 355 | +                f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." | 
|  | 356 | +            ) | 
|  | 357 | + | 
|  | 358 | +    elif backend == AttentionBackendName._NATIVE_XLA: | 
|  | 359 | +        if not _CAN_USE_XLA_ATTN: | 
|  | 360 | +            raise RuntimeError( | 
|  | 361 | +                f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`." | 
|  | 362 | +            ) | 
|  | 363 | + | 
|  | 364 | +    elif backend == AttentionBackendName.XFORMERS: | 
|  | 365 | +        if not _CAN_USE_XFORMERS_ATTN: | 
|  | 366 | +            raise RuntimeError( | 
|  | 367 | +                f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." | 
|  | 368 | +            ) | 
|  | 369 | + | 
|  | 370 | + | 
| 308 | 371 | @functools.lru_cache(maxsize=128) | 
| 309 | 372 | def _prepare_for_flash_attn_or_sage_varlen_without_mask( | 
| 310 | 373 |     batch_size: int, | 
|  | 
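The new `_check_attention_backend_requirements` turns a missing or outdated optional dependency into an explicit `RuntimeError` at selection time, while the cached `_prepare_for_flash_attn_or_sage_varlen_without_mask` helper (its body lies outside this hunk) prepares the metadata that varlen kernels consume. For context, a sketch of the usual cumulative-sequence-length preparation such kernels expect; this is an assumption about the helper's role, not its actual implementation:

```python
import torch

def _cu_seqlens_from_lengths(seq_lens: torch.Tensor):
    # seq_lens: (batch_size,) int tensor of valid token counts per sample.
    # Varlen flash/sage kernels take packed tokens plus cumulative offsets.
    cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return cu_seqlens, int(seq_lens.max().item())

lengths = torch.tensor([128, 64, 96])
cu_seqlens, max_seqlen = _cu_seqlens_from_lengths(lengths)
# cu_seqlens -> tensor([  0, 128, 192, 288], dtype=torch.int32), max_seqlen -> 128
```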