
Commit 23e7548

update
1 parent e9fd0ca commit 23e7548

2 files changed: +8 -10 lines changed

2 files changed

+8
-10
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 5 additions & 7 deletions
@@ -187,13 +187,16 @@ def list_backends(cls):
 
 
 @contextlib.contextmanager
-def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
     if backend not in _AttentionBackendRegistry._backends:
         raise ValueError(f"Backend {backend} is not registered.")
 
+    backend = AttentionBackendName(backend)
+    _check_attention_backend_requirements(backend)
+
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
 
@@ -226,8 +229,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    _check_backend_requirements(backend_name)
-
     kwargs = {
         "query": query,
         "key": key,
@@ -316,10 +317,7 @@ def _check_shape(
 # ===== Helper functions =====
 
 
-# LRU cache is hack to avoid checking the backend requirements multiple times. Maybe not needed
-# because CPU is running much farther ahead of the accelerator and this will not be blocking anyway.
-@functools.lru_cache(maxsize=16)
-def _check_backend_requirements(backend: AttentionBackendName) -> None:
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
     if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
         if not _CAN_USE_FLASH_ATTN:
             raise RuntimeError(
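
A minimal usage sketch (not part of this commit) of the updated context manager, assuming the enum's string values are the lowercase backend names (e.g. "native"):

from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend

# Passing the enum member works as before.
with attention_backend(AttentionBackendName.NATIVE):
    ...  # run the model; dispatch_attention_fn picks up the active backend

# Passing a plain string is now also accepted; it is converted via
# AttentionBackendName(backend) and validated by
# _check_attention_backend_requirements when the context is entered.
with attention_backend("native"):
    ...

Since the requirement check now runs once at selection time instead of on every dispatch_attention_fn call, the functools.lru_cache workaround on the old _check_backend_requirements helper is no longer needed.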

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 3 deletions
@@ -622,7 +622,7 @@ def set_attention_backend(self, backend: str) -> None:
         attention as backend.
         """
         from .attention import AttentionModuleMixin
-        from .attention_dispatch import AttentionBackendName
+        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
 
         # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
         from .attention_processor import Attention, MochiAttention
@@ -633,10 +633,10 @@ def set_attention_backend(self, backend: str) -> None:
         available_backends = {x.value for x in AttentionBackendName.__members__.values()}
         if backend not in available_backends:
             raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
-
         backend = AttentionBackendName(backend)
-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        _check_attention_backend_requirements(backend)
 
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
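
For illustration, a hypothetical call through the model-level API (the model class and checkpoint path are placeholders; "flash" assumes it is the string value of AttentionBackendName.FLASH):

from diffusers import UNet2DConditionModel

# Placeholder checkpoint for illustration only.
model = UNet2DConditionModel.from_pretrained("path/to/checkpoint", subfolder="unet")

# With this change, missing requirements (e.g. flash-attn absent, so
# _CAN_USE_FLASH_ATTN is False) raise RuntimeError here, when the backend is
# selected, rather than later inside dispatch_attention_fn.
model.set_attention_backend("flash")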
