|
38 | 38 | from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS |
39 | 39 |
|
40 | 40 |
|
41 | | -logger = get_logger(__name__) # pylint: disable=invalid-name |
42 | | - |
43 | | - |
44 | | -if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): |
| 41 | +_REQUIRED_FLASH_VERSION = "2.6.3" |
| 42 | +_REQUIRED_SAGE_VERSION = "2.1.1" |
| 43 | +_REQUIRED_FLEX_VERSION = "2.5.0" |
| 44 | +_REQUIRED_XLA_VERSION = "2.2" |
| 45 | +_REQUIRED_XFORMERS_VERSION = "0.0.29" |
| 46 | + |
| 47 | +_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) |
| 48 | +_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() |
| 49 | +_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) |
| 50 | +_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) |
| 51 | +_CAN_USE_NPU_ATTN = is_torch_npu_available() |
| 52 | +_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) |
| 53 | +_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) |
| 54 | + |
| 55 | + |
| 56 | +if _CAN_USE_FLASH_ATTN: |
45 | 57 | from flash_attn import flash_attn_func, flash_attn_varlen_func |
46 | 58 | else: |
47 | | - logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") |
48 | 59 | flash_attn_func = None |
49 | 60 | flash_attn_varlen_func = None |
50 | 61 |
|
51 | 62 |
|
52 | | -if is_flash_attn_3_available(): |
| 63 | +if _CAN_USE_FLASH_ATTN_3: |
53 | 64 | from flash_attn_interface import flash_attn_func as flash_attn_3_func |
54 | 65 | from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func |
55 | 66 | else: |
56 | 67 | flash_attn_3_func = None |
57 | 68 | flash_attn_3_varlen_func = None |
58 | 69 |
|
59 | 70 |
|
60 | | -if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): |
| 71 | +if _CAN_USE_SAGE_ATTN: |
61 | 72 | from sageattention import ( |
62 | 73 | sageattn, |
63 | 74 | sageattn_qk_int8_pv_fp8_cuda, |
|
67 | 78 | sageattn_varlen, |
68 | 79 | ) |
69 | 80 | else: |
70 | | - logger.warning( |
71 | | - "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." |
72 | | - ) |
73 | 81 | sageattn = None |
74 | 82 | sageattn_qk_int8_pv_fp16_cuda = None |
75 | 83 | sageattn_qk_int8_pv_fp16_triton = None |
|
78 | 86 | sageattn_varlen = None |
79 | 87 |
|
80 | 88 |
|
81 | | -if is_torch_version(">=", "2.5.0"): |
| 89 | +if _CAN_USE_FLEX_ATTN: |
82 | 90 | # We cannot import the flex_attention function from the package directly because it is expected (from the |
83 | 91 | # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the |
84 | 92 | # compiled function. |
85 | 93 | import torch.nn.attention.flex_attention as flex_attention |
86 | 94 |
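
A minimal sketch of why the module, not the function, is imported: the dispatcher looks `flex_attention` up through the module attribute, so a user-compiled kernel is picked up at call time. The rebinding below is a hypothetical user-side step and assumes `torch>=2.5.0`.

```python
import torch
import torch.nn.attention.flex_attention as flex_attention_module
from torch.nn.attention.flex_attention import flex_attention as flex_attention_direct

# Hypothetical user-side step: compile the kernel and rebind it on the module.
flex_attention_module.flex_attention = torch.compile(flex_attention_module.flex_attention)

# The module-attribute lookup (what the dispatcher relies on) now resolves to
# the compiled callable ...
print(flex_attention_module.flex_attention)

# ... while a name bound via `from ... import` at load time still points to
# the original, uncompiled function.
print(flex_attention_direct)
```
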
|
87 | 95 |
|
88 | | -if is_torch_npu_available(): |
| 96 | +if _CAN_USE_NPU_ATTN: |
89 | 97 | from torch_npu import npu_fusion_attention |
90 | 98 | else: |
91 | 99 | npu_fusion_attention = None |
92 | 100 |
|
93 | 101 |
|
94 | | -if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): |
| 102 | +if _CAN_USE_XLA_ATTN: |
95 | 103 | from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention |
96 | 104 | else: |
97 | 105 | xla_flash_attention = None |
98 | 106 |
|
99 | 107 |
|
100 | | -if is_xformers_available() and is_xformers_version(">=", "0.0.29"): |
| 108 | +if _CAN_USE_XFORMERS_ATTN: |
101 | 109 | import xformers.ops as xops |
102 | 110 | else: |
103 | | - logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.") |
104 | 111 | xops = None |
105 | 112 |
|
106 | 113 |
|
| 114 | +logger = get_logger(__name__) # pylint: disable=invalid-name |
| 115 | + |
107 | 116 | # TODO(aryan): Add support for the following: |
108 | 117 | # - Sage Attention++ |
109 | 118 | # - block sparse, radial and other attention methods |
110 | 119 | # - CP with sage attention, flex, xformers, other missing backends |
111 | 120 | # - Add support for normal and CP training with backends that don't support it yet |
112 | 121 |
|
113 | | - |
114 | 122 | _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] |
115 | 123 | _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] |
116 | 124 | _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] |
@@ -171,6 +179,7 @@ def decorator(func): |
171 | 179 |
|
172 | 180 | @classmethod |
173 | 181 | def get_active_backend(cls): |
| 182 | + _check_backend_requirements(cls._active_backend) |
174 | 183 | return cls._active_backend, cls._backends[cls._active_backend] |
175 | 184 |
|
176 | 185 | @classmethod |
@@ -226,9 +235,10 @@ def dispatch_attention_fn( |
226 | 235 | "dropout_p": dropout_p, |
227 | 236 | "is_causal": is_causal, |
228 | 237 | "scale": scale, |
229 | | - "enable_gqa": enable_gqa, |
230 | 238 | **attention_kwargs, |
231 | 239 | } |
| 240 | + if is_torch_version(">=", "2.5.0"): |
| 241 | + kwargs["enable_gqa"] = enable_gqa |
232 | 242 |
|
233 | 243 | if _AttentionBackendRegistry._checks_enabled: |
234 | 244 | removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) |
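
The gate above keeps `enable_gqa` out of the kwargs on PyTorch builds that predate the argument, since `scaled_dot_product_attention` only accepts it from 2.5.0 onwards. A minimal sketch of the same version-gating pattern, with a made-up helper name (`sdpa_with_optional_gqa`):

```python
import torch
import torch.nn.functional as F
from packaging import version


def sdpa_with_optional_gqa(query, key, value, enable_gqa=False):
    # Hypothetical helper: `enable_gqa` is only forwarded on torch>=2.5.0,
    # where scaled_dot_product_attention accepts it; older versions would
    # raise a TypeError for the unknown keyword argument.
    kwargs = {}
    if version.parse(torch.__version__) >= version.parse("2.5.0"):
        kwargs["enable_gqa"] = enable_gqa
    return F.scaled_dot_product_attention(query, key, value, **kwargs)


q = k = v = torch.randn(1, 4, 16, 8)
out = sdpa_with_optional_gqa(q, k, v)
print(out.shape)  # torch.Size([1, 4, 16, 8])
```
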
@@ -305,6 +315,60 @@ def _check_shape( |
305 | 315 | # ===== Helper functions ===== |
306 | 316 |
|
307 | 317 |
|
| 318 | +# The LRU cache is a hack to avoid re-checking the backend requirements on every call. It may not be
| 319 | +# needed, since the CPU typically runs far ahead of the accelerator, so this check should not block anyway.
| 320 | +@functools.lru_cache(maxsize=16) |
| 321 | +def _check_backend_requirements(backend: AttentionBackendName) -> None: |
| 322 | + if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: |
| 323 | + if not _CAN_USE_FLASH_ATTN: |
| 324 | + raise RuntimeError( |
| 325 | +                f"Flash Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
| 326 | + ) |
| 327 | + |
| 328 | + elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: |
| 329 | + if not _CAN_USE_FLASH_ATTN_3: |
| 330 | + raise RuntimeError( |
| 331 | +                f"Flash Attention 3 backend '{backend.value}' is not usable because of a missing package or an outdated version. Please build the FA3 beta release from source."
| 332 | + ) |
| 333 | + |
| 334 | + elif backend in [ |
| 335 | + AttentionBackendName.SAGE, |
| 336 | + AttentionBackendName.SAGE_VARLEN, |
| 337 | + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, |
| 338 | + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, |
| 339 | + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, |
| 340 | + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, |
| 341 | + ]: |
| 342 | + if not _CAN_USE_SAGE_ATTN: |
| 343 | + raise RuntimeError( |
| 344 | +                f"Sage Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
| 345 | + ) |
| 346 | + |
| 347 | + elif backend == AttentionBackendName.FLEX: |
| 348 | + if not _CAN_USE_FLEX_ATTN: |
| 349 | + raise RuntimeError( |
| 350 | +                f"Flex Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `torch>={_REQUIRED_FLEX_VERSION}`."
| 351 | + ) |
| 352 | + |
| 353 | + elif backend == AttentionBackendName._NATIVE_NPU: |
| 354 | + if not _CAN_USE_NPU_ATTN: |
| 355 | + raise RuntimeError( |
| 356 | +                f"NPU Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `torch_npu`."
| 357 | + ) |
| 358 | + |
| 359 | + elif backend == AttentionBackendName._NATIVE_XLA: |
| 360 | + if not _CAN_USE_XLA_ATTN: |
| 361 | + raise RuntimeError( |
| 362 | +                f"XLA Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
| 363 | + ) |
| 364 | + |
| 365 | + elif backend == AttentionBackendName.XFORMERS: |
| 366 | + if not _CAN_USE_XFORMERS_ATTN: |
| 367 | + raise RuntimeError( |
| 368 | +                f"Xformers Attention backend '{backend.value}' is not usable because of a missing package or an outdated version. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
| 369 | + ) |
| 370 | + |
| 371 | + |
308 | 372 | @functools.lru_cache(maxsize=128) |
309 | 373 | def _prepare_for_flash_attn_or_sage_varlen_without_mask( |
310 | 374 | batch_size: int, |
|
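
For context, a short sketch of how the lazy requirement check behaves at use time (the import path and the direct call into private helpers below are assumptions for illustration, not a documented public API):

```python
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_backend_requirements,
)

# On a machine without `sageattention` installed, the lazy check raises an
# actionable error at use time instead of warning at import time.
try:
    _check_backend_requirements(AttentionBackendName.SAGE)
except RuntimeError as err:
    print(err)  # e.g. "... Please install `sageattention>=2.1.1`."

# Note: functools.lru_cache only memoizes successful returns, so a passing
# check is cached per backend name, while a failing check re-runs (and
# re-raises a fresh error) on every call.
```
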