Showing 1 changed file with 2 additions and 1 deletion.

@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)
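In effect, the hook previously attached FasterCache state only to submodules that are instances of the legacy attention classes; the broadened check also catches attention blocks that inherit from AttentionModuleMixin. A minimal, self-contained sketch of the isinstance change, using hypothetical stand-in classes rather than the real diffusers ones:

# Stand-ins for illustration only; the real Attention, MochiAttention, and
# AttentionModuleMixin come from diffusers.models.
class Attention: pass
class MochiAttention: pass
class AttentionModuleMixin: pass

_ATTENTION_CLASSES = (Attention, MochiAttention)

class MixinOnlyAttention(AttentionModuleMixin):
    # An attention block that inherits the mixin but not the legacy classes.
    pass

submodule = MixinOnlyAttention()

# Old check: mixin-based blocks fall through and never get the cache hook.
print(isinstance(submodule, _ATTENTION_CLASSES))  # False

# New check: star-unpacking extends the tuple, so isinstance also matches
# any subclass of AttentionModuleMixin.
print(isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)))  # True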