Skip to content

Commit ea18eb6

Browse files
committed
add fastercache to CacheMixin
1 parent 93de5f3 commit ea18eb6

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

src/diffusers/hooks/faster_cache.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
3131
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
3232
_ATTENTION_CLASSES = (Attention, MochiAttention)
33-
3433
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
3534
"blocks.*attn",
3635
"transformer_blocks.*attn",

src/diffusers/hooks/pyramid_attention_broadcast.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2727

2828

29+
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
2930
_ATTENTION_CLASSES = (Attention, MochiAttention)
30-
3131
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
3232
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
3333
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@@ -311,4 +311,4 @@ def _apply_pyramid_attention_broadcast_hook(
311311
"""
312312
registry = HookRegistry.check_if_exists_or_initialize(module)
313313
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
314-
registry.register_hook(hook, "pyramid_attention_broadcast")
314+
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)

src/diffusers/models/cache_utils.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class CacheMixin:
2424
2525
Supported caching techniques:
2626
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
27+
- [FasterCache](https://huggingface.co/papers/2410.19355)
2728
"""
2829

2930
_cache_config = None
@@ -59,25 +60,43 @@ def enable_cache(self, config) -> None:
5960
```
6061
"""
6162

62-
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
63+
from ..hooks import (
64+
FasterCacheConfig,
65+
PyramidAttentionBroadcastConfig,
66+
apply_faster_cache,
67+
apply_pyramid_attention_broadcast,
68+
)
69+
70+
if self.is_cache_enabled:
71+
raise ValueError(
72+
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
73+
)
6374

6475
if isinstance(config, PyramidAttentionBroadcastConfig):
6576
apply_pyramid_attention_broadcast(self, config)
77+
elif isinstance(config, FasterCacheConfig):
78+
apply_faster_cache(self, config)
6679
else:
6780
raise ValueError(f"Cache config {type(config)} is not supported.")
6881

6982
self._cache_config = config
7083

7184
def disable_cache(self) -> None:
72-
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
85+
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
86+
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
87+
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
7388

7489
if self._cache_config is None:
7590
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
7691
return
7792

7893
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
7994
registry = HookRegistry.check_if_exists_or_initialize(self)
80-
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
95+
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
96+
elif isinstance(self._cache_config, FasterCacheConfig):
97+
registry = HookRegistry.check_if_exists_or_initialize(self)
98+
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
99+
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
81100
else:
82101
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
83102

0 commit comments

Comments
 (0)