80 | 80 | raise ImportError( |
81 | 81 | "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`." |
82 | 82 | ) |
83 | | - from ..utils.kernels_utils import _get_fa3_from_hub |
| 83 | + from ..utils.kernels_utils import _get_fa3_from_hub, get_fa_from_hub |
84 | 84 |
85 | | - flash_attn_interface_hub = _get_fa3_from_hub() |
86 | | - flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func |
| 85 | + fa3_interface_hub = _get_fa3_from_hub() |
| 86 | + flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func |
| 87 | + fa_interface_hub = get_fa_from_hub() |
| 88 | + flash_attn_func_hub = fa_interface_hub.flash_attn_func |
87 | 89 | else: |
88 | 90 | flash_attn_3_func_hub = None |
| 91 | + flash_attn_func_hub = None |
89 | 92 |
90 | 93 | if _CAN_USE_SAGE_ATTN: |
91 | 94 | from sageattention import (
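
For reference, `get_fa_from_hub` is presumably a thin wrapper over the `kernels` library, analogous to the existing `_get_fa3_from_hub` helper. A minimal sketch of what it is assumed to do; the repo id `kernels-community/flash-attn` is an assumption and is not confirmed by this diff:

from kernels import get_kernel

def get_fa_from_hub():
    # Download (and cache) the prebuilt flash-attention kernel from the Hub.
    # The returned module is expected to expose `flash_attn_func` with the usual
    # `flash-attn` signature, which is what gets aliased to `flash_attn_func_hub` above.
    return get_kernel("kernels-community/flash-attn")
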
@@ -170,6 +173,8 @@ class AttentionBackendName(str, Enum): |
170 | 173 | # `flash-attn` |
171 | 174 | FLASH = "flash" |
172 | 175 | FLASH_VARLEN = "flash_varlen" |
| 176 | + FLASH_HUB = "flash_hub" |
| 177 | + # FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet. |
173 | 178 | _FLASH_3 = "_flash_3" |
174 | 179 | _FLASH_VARLEN_3 = "_flash_varlen_3" |
175 | 180 | _FLASH_3_HUB = "_flash_3_hub" |
@@ -400,15 +405,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None |
400 | 405 | f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." |
401 | 406 | ) |
402 | 407 |
403 | | - # TODO: add support Hub variant of FA3 varlen later |
404 | | - elif backend in [AttentionBackendName._FLASH_3_HUB]: |
| 408 | + # TODO: add support for Hub variants of FA and FA3 varlen later
| 409 | + elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]: |
405 | 410 | if not DIFFUSERS_ENABLE_HUB_KERNELS: |
406 | 411 | raise RuntimeError( |
407 | | - f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`." |
| 412 | + f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`." |
408 | 413 | ) |
409 | 414 | if not is_kernels_available(): |
410 | 415 | raise RuntimeError( |
411 | | - f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." |
| 416 | + f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." |
412 | 417 | ) |
413 | 418 |
414 | 419 | elif backend in [ |
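
Both Hub backends are gated on the same two prerequisites checked above. A minimal setup sketch, assuming `DIFFUSERS_ENABLE_HUB_KERNELS` is read when diffusers is imported:

# Install the kernels package first: pip install kernels
# Either export the variable in the shell (export DIFFUSERS_ENABLE_HUB_KERNELS=yes)
# or set it in-process before importing diffusers:
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import diffusers  # noqa: E402  # imported only after the env var is in place
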
@@ -1225,6 +1230,36 @@ def _flash_attention( |
1225 | 1230 | return (out, lse) if return_lse else out |
1226 | 1231 |
1227 | 1232 |
| 1233 | +@_AttentionBackendRegistry.register( |
| 1234 | + AttentionBackendName.FLASH_HUB, |
| 1235 | + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| 1236 | +) |
| 1237 | +def _flash_attention_hub( |
| 1238 | + query: torch.Tensor, |
| 1239 | + key: torch.Tensor, |
| 1240 | + value: torch.Tensor, |
| 1241 | + dropout_p: float = 0.0, |
| 1242 | + is_causal: bool = False, |
| 1243 | + scale: Optional[float] = None, |
| 1244 | + return_lse: bool = False, |
| 1245 | + _parallel_config: Optional["ParallelConfig"] = None, |
| 1246 | +) -> torch.Tensor: |
| 1247 | + lse = None |
| 1248 | + out = flash_attn_func_hub(
| 1249 | + q=query, |
| 1250 | + k=key, |
| 1251 | + v=value, |
| 1252 | + dropout_p=dropout_p, |
| 1253 | + softmax_scale=scale, |
| 1254 | + causal=is_causal, |
| 1255 | + return_attn_probs=return_lse, |
| 1256 | + ) |
| 1257 | + if return_lse: |
| 1258 | + out, lse, *_ = out |
| 1259 | + |
| 1260 | + return (out, lse) if return_lse else out |
| 1261 | + |
| 1262 | + |
1228 | 1263 | @_AttentionBackendRegistry.register( |
1229 | 1264 | AttentionBackendName.FLASH_VARLEN, |
1230 | 1265 | constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
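
End to end, selecting the new backend might look like the sketch below. It assumes diffusers' `set_attention_backend` model API and uses Flux purely for illustration; neither is part of this diff, and the prerequisites above (`kernels` installed, `DIFFUSERS_ENABLE_HUB_KERNELS` set) still apply:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the transformer's attention through the Hub-provided flash-attn kernel.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
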