Skip to content

Commit 00e9b96

Browse files
committed
fix kandinsky errors
1 parent ebbebbe commit 00e9b96

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

src/diffusers/hooks/_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class AttentionProcessorMetadata:
2626
class TransformerBlockMetadata:
2727
return_hidden_states_index: int = None
2828
return_encoder_hidden_states_index: int = None
29+
hidden_states_argument_name: str = "hidden_states"
2930

3031
_cls: Type = None
3132
_cached_parameter_indices: Dict[str, int] = None
@@ -346,6 +347,7 @@ def _register_transformer_blocks_metadata():
346347
metadata=TransformerBlockMetadata(
347348
return_hidden_states_index=0,
348349
return_encoder_hidden_states_index=None,
350+
hidden_states_argument_name="visual_embed",
349351
),
350352
)
351353

src/diffusers/hooks/mag_cache.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
182182
if self.state_manager._current_context is None:
183183
self.state_manager.set_context("inference")
184184

185-
# Capture input hidden_states
186-
hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
185+
arg_name = self._metadata.hidden_states_argument_name
186+
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
187+
187188

188189
state: MagCacheState = self.state_manager.get_state()
189190
state.head_block_input = hidden_states
@@ -297,7 +298,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
297298
state: MagCacheState = self.state_manager.get_state()
298299

299300
if not state.should_compute:
300-
hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
301+
arg_name = self._metadata.hidden_states_argument_name
302+
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
303+
301304
if self.is_tail:
302305
# Still need to advance step index even if we skip
303306
self._advance_step(state)

src/diffusers/models/cache_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ def enable_cache(self, config) -> None:
6666
from ..hooks import (
6767
FasterCacheConfig,
6868
FirstBlockCacheConfig,
69+
MagCacheConfig,
6970
PyramidAttentionBroadcastConfig,
7071
TaylorSeerCacheConfig,
7172
apply_faster_cache,
7273
apply_first_block_cache,
74+
apply_mag_cache,
7375
apply_pyramid_attention_broadcast,
7476
apply_taylorseer_cache,
7577
)
@@ -83,6 +85,8 @@ def enable_cache(self, config) -> None:
8385
apply_faster_cache(self, config)
8486
elif isinstance(config, FirstBlockCacheConfig):
8587
apply_first_block_cache(self, config)
88+
elif isinstance(config, MagCacheConfig):
89+
apply_mag_cache(self, config)
8690
elif isinstance(config, PyramidAttentionBroadcastConfig):
8791
apply_pyramid_attention_broadcast(self, config)
8892
elif isinstance(config, TaylorSeerCacheConfig):
@@ -97,11 +101,13 @@ def disable_cache(self) -> None:
97101
FasterCacheConfig,
98102
FirstBlockCacheConfig,
99103
HookRegistry,
104+
MagCacheConfig,
100105
PyramidAttentionBroadcastConfig,
101106
TaylorSeerCacheConfig,
102107
)
103108
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
104109
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
110+
from ..hooks.mag_cache import _MAG_CACHE_LEADER_BLOCK_HOOK
105111
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
106112
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
107113

@@ -116,6 +122,8 @@ def disable_cache(self) -> None:
116122
elif isinstance(self._cache_config, FirstBlockCacheConfig):
117123
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
118124
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
125+
elif isinstance(self._cache_config, MagCacheConfig):
126+
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
119127
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
120128
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
121129
elif isinstance(self._cache_config, TaylorSeerCacheConfig):

0 commit comments

Comments
 (0)