Commit 6e9f81f

Commit message: up
1 parent 548f56e commit 6e9f81f

File tree: 3 files changed (+35, -25 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 24 additions & 22 deletions
@@ -17,7 +17,6 @@
 import inspect
 import math
 from enum import Enum
-from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch
@@ -145,6 +144,9 @@ def wrap(func):
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]

+flash_attn_3_hub_func = None
+__fa3_hub_loaded = False
+

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -210,20 +212,20 @@ def list_backends(cls):
         return list(cls._backends.keys())


-@lru_cache(maxsize=None)
-def _load_fa3_hub():
+def _ensure_fa3_hub_loaded():
+    global __fa3_hub_loaded
+    if __fa3_hub_loaded:
+        return
     from ..utils.kernels_utils import _get_fa3_from_hub

-    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
-    if fa3_hub is None:
+    fa3_hub_module = _get_fa3_from_hub()  # doesn't retrigger download if already available.
+    if fa3_hub_module is None:
         raise RuntimeError(
             "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
         )
-    return fa3_hub
-
-
-def flash_attn_3_hub_func(*args, **kwargs):
-    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+    global flash_attn_3_hub_func
+    flash_attn_3_hub_func = fa3_hub_module.flash_attn_func
+    __fa3_hub_loaded = True


 @contextlib.contextmanager
@@ -540,20 +542,20 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.empty_like(query), query.new_empty(lse_shape)


-@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_hub(
-    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    out, lse = flash_attn_3_hub_func(query, key, value)
-    lse = lse.permute(0, 2, 1)
-    return out, lse
+# @_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
+# def _wrapped_flash_attn_3_hub(
+#     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+# ) -> Tuple[torch.Tensor, torch.Tensor]:
+#     out, lse = flash_attn_3_hub_func(query, key, value)
+#     lse = lse.permute(0, 2, 1)
+#     return out, lse


-@_register_fake("vllm_flash_attn3::flash_attn")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    batch_size, seq_len, num_heads, head_dim = query.shape
-    lse_shape = (batch_size, seq_len, num_heads)
-    return torch.empty_like(query), query.new_empty(lse_shape)
+# @_register_fake("vllm_flash_attn3::flash_attn")
+# def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+#     batch_size, seq_len, num_heads, head_dim = query.shape
+#     lse_shape = (batch_size, seq_len, num_heads)
+#     return torch.empty_like(query), query.new_empty(lse_shape)


 # ===== Attention backends =====
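
In short, attention_dispatch.py stops memoizing the loader with lru_cache and stops wrapping every call: it keeps a module-level "loaded" flag, rebinds the global flash_attn_3_hub_func to the Hub kernel's flash_attn_func, and comments out the vllm_flash_attn3 custom-op and fake-op registrations. A minimal, self-contained sketch of that load-once pattern follows; _fetch_hub_kernel is a stand-in for _get_fa3_from_hub, not real diffusers code.

# Illustrative sketch only: the load-once pattern adopted above, runnable in isolation.
from types import SimpleNamespace
from typing import Any, Callable, Optional

flash_attn_3_hub_func: Optional[Callable[..., Any]] = None
_fa3_hub_loaded = False


def _fetch_hub_kernel() -> Optional[SimpleNamespace]:
    # Stand-in for _get_fa3_from_hub(); pretend it returned a kernel module
    # exposing flash_attn_func.
    return SimpleNamespace(flash_attn_func=lambda q, k, v: (q, None))


def _ensure_fa3_hub_loaded() -> None:
    """Resolve the Hub kernel once and rebind the module-level callable."""
    global flash_attn_3_hub_func, _fa3_hub_loaded
    if _fa3_hub_loaded:
        return  # later calls are no-ops
    module = _fetch_hub_kernel()
    if module is None:
        raise RuntimeError("Failed to load FlashAttention-3 kernels from the Hub.")
    # Bind the kernel's function directly instead of routing every call through
    # a wrapper, as the pre-commit lru_cache version did.
    flash_attn_3_hub_func = module.flash_attn_func
    _fa3_hub_loaded = True


_ensure_fa3_hub_loaded()
print(flash_attn_3_hub_func("q", "k", "v"))  # ('q', None) with the stand-in kernel

Rebinding the global means callers invoke the kernel function directly, with no wrapper frame in between.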

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 1 deletion
@@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
            attention as backend.
        """
        from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+        from .attention_dispatch import (
+            AttentionBackendName,
+            _check_attention_backend_requirements,
+            _ensure_fa3_hub_loaded,
+        )

        # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention
@@ -608,6 +612,10 @@ def set_attention_backend(self, backend: str) -> None:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
        backend = AttentionBackendName(backend)
        _check_attention_backend_requirements(backend)
+        # TODO: clean this once it gets exhausted.
+        if "_flash_3_hub" in backend:
+            # We ensure it's preloaded to reduce overhead and also to avoid compilation errors.
+            _ensure_fa3_hub_loaded()

        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
        for module in self.modules():
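
Net effect for users: selecting a FlashAttention-3 Hub backend now resolves the kernel up front rather than on the first attention call. A hedged usage sketch follows; the checkpoint id and the transformer attribute are placeholders, and the backend string is assumed to contain "_flash_3_hub", which is all the check above requires.

# Hypothetical usage sketch; checkpoint id and module attribute are assumptions.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-checkpoint",  # placeholder checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# With this commit, the call below also runs _ensure_fa3_hub_loaded(), so the
# Hub kernel is fetched before the first forward pass (and before any
# torch.compile tracing touches the attention path).
pipe.transformer.set_attention_backend("_flash_3_hub")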

src/diffusers/utils/kernels_utils.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 logger = get_logger(__name__)


-_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"


 def _get_fa3_from_hub():
@@ -15,7 +15,7 @@ def _get_fa3_from_hub():
     from kernels import get_kernel

     try:
-        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops")
         return flash_attn_3_hub
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
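
For reference, the new repo id and pinned revision can be exercised directly with the kernels package; the sketch below mirrors _get_fa3_from_hub and uses only the call shown in the diff.

# Standalone sketch: fetch the FlashAttention-3 kernel the way _get_fa3_from_hub
# does after this commit. Requires the `kernels` package and a platform with a
# matching prebuilt wheel.
from kernels import get_kernel

try:
    fa3 = get_kernel("kernels-community/flash-attn3", revision="fake-ops")
except Exception as exc:  # e.g. network failure or no wheel for this platform
    fa3 = None
    print(f"Could not fetch FlashAttention-3 kernel: {exc}")

if fa3 is not None:
    # The loaded module exposes flash_attn_func, which attention_dispatch
    # binds to flash_attn_3_hub_func.
    print(hasattr(fa3, "flash_attn_func"))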
