Commit bc64f12

fix fastercache implementation
1 parent 0cda91d commit bc64f12

File tree

1 file changed: +2 -1 lines changed

src/diffusers/hooks/faster_cache.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
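
For context: the one-line change widens the isinstance filter so that FasterCache hooks are also attached to attention modules inheriting from AttentionModuleMixin, not only to the classes enumerated in _ATTENTION_CLASSES (Attention, MochiAttention). Below is a minimal sketch of the before/after behavior; the class definitions are hypothetical stand-ins, not the real diffusers implementations, and only the isinstance pattern mirrors the actual change.

```python
# Minimal sketch with stand-in classes; only the isinstance pattern
# mirrors the actual change in faster_cache.py.
import torch

class AttentionModuleMixin:
    """Stand-in for diffusers' AttentionModuleMixin."""

class LegacyAttention(torch.nn.Module):
    """Stand-in for the classes collected in _ATTENTION_CLASSES."""

class MixinBasedAttention(torch.nn.Module, AttentionModuleMixin):
    """An attention module that inherits only the mixin."""

_ATTENTION_CLASSES = (LegacyAttention,)

module = MixinBasedAttention()

# Old check: mixin-based attention modules were silently skipped,
# so FasterCache never hooked them.
print(isinstance(module, _ATTENTION_CLASSES))  # False

# New check: unpacking the tuple and appending the mixin matches both
# the legacy classes and any module built on AttentionModuleMixin.
print(isinstance(module, (*_ATTENTION_CLASSES, AttentionModuleMixin)))  # True
```

Matching on a shared mixin rather than an enumerated class list also means attention implementations added later are picked up by this loop without further changes here.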
