 import inspect
 import math
 from enum import Enum
-from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch
@@ -145,6 +144,9 @@ def wrap(func):
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]

+flash_attn_3_hub_func = None
+__fa3_hub_loaded = False
+

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -210,20 +212,20 @@ def list_backends(cls):
         return list(cls._backends.keys())


-@lru_cache(maxsize=None)
-def _load_fa3_hub():
+def _ensure_fa3_hub_loaded():
+    global __fa3_hub_loaded
+    if __fa3_hub_loaded:
+        return
     from ..utils.kernels_utils import _get_fa3_from_hub

-    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
-    if fa3_hub is None:
+    fa3_hub_module = _get_fa3_from_hub()  # doesn't retrigger download if already available.
+    if fa3_hub_module is None:
         raise RuntimeError(
             "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
         )
-    return fa3_hub
-
-
-def flash_attn_3_hub_func(*args, **kwargs):
-    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+    global flash_attn_3_hub_func
+    flash_attn_3_hub_func = fa3_hub_module.flash_attn_func
+    __fa3_hub_loaded = True


 @contextlib.contextmanager
@@ -540,20 +542,20 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.empty_like(query), query.new_empty(lse_shape)


-@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_hub(
-    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    out, lse = flash_attn_3_hub_func(query, key, value)
-    lse = lse.permute(0, 2, 1)
-    return out, lse
+# @_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
+# def _wrapped_flash_attn_3_hub(
+#     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+# ) -> Tuple[torch.Tensor, torch.Tensor]:
+#     out, lse = flash_attn_3_hub_func(query, key, value)
+#     lse = lse.permute(0, 2, 1)
+#     return out, lse


-@_register_fake("vllm_flash_attn3::flash_attn")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    batch_size, seq_len, num_heads, head_dim = query.shape
-    lse_shape = (batch_size, seq_len, num_heads)
-    return torch.empty_like(query), query.new_empty(lse_shape)
+# @_register_fake("vllm_flash_attn3::flash_attn")
+# def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+#     batch_size, seq_len, num_heads, head_dim = query.shape
+#     lse_shape = (batch_size, seq_len, num_heads)
+#     return torch.empty_like(query), query.new_empty(lse_shape)


 # ===== Attention backends =====
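For context on how the lazy loader above is meant to be consumed: `flash_attn_3_hub_func` now starts out as None and is only rebound to the Hub kernel's `flash_attn_func` once `_ensure_fa3_hub_loaded()` has run successfully (otherwise that call raises). Below is a minimal usage sketch living in the same module as the diff; `_flash_attention_3_hub` and its signature are hypothetical illustrations, not part of this diff, and only `_ensure_fa3_hub_loaded` and `flash_attn_3_hub_func` come from the code above.

import torch

def _flash_attention_3_hub(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Hypothetical backend entry point, assumed to sit in the same module as the diff.
    # Resolve the Hub kernel on first use; later calls return immediately because
    # __fa3_hub_loaded is already True.
    _ensure_fa3_hub_loaded()
    # flash_attn_3_hub_func is now bound to the Hub module's flash_attn_func, which
    # returns the attention output and the log-sum-exp, as in the removed
    # _wrapped_flash_attn_3_hub wrapper.
    out, lse = flash_attn_3_hub_func(query, key, value)
    return out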