Commit 0a05bec ("improvements")

Parent: 37f8826

3 files changed: +71, -33 lines

src/diffusers/hooks/_helpers.py

Lines changed: 21 additions & 1 deletion
@@ -169,7 +169,7 @@ def _register_attention_processors_metadata():
 
 
 def _register_transformer_blocks_metadata():
-    from ..models.attention import BasicTransformerBlock
+    from ..models.attention import BasicTransformerBlock, JointTransformerBlock
     from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
     from ..models.transformers.transformer_bria import BriaTransformerBlock
     from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
@@ -189,6 +189,7 @@ def _register_transformer_blocks_metadata():
     from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
     from ..models.transformers.transformer_wan import WanTransformerBlock
     from ..models.transformers.transformer_z_image import ZImageTransformerBlock
+    from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
 
     # BasicTransformerBlock
     TransformerBlockRegistry.register(
@@ -332,6 +333,25 @@ def _register_transformer_blocks_metadata():
     )
 
 
+    TransformerBlockRegistry.register(
+        model_class=JointTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=1,
+            return_encoder_hidden_states_index=0,
+        ),
+    )
+
+
+    # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
+    TransformerBlockRegistry.register(
+        model_class=Kandinsky5TransformerDecoderBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=0,
+            return_encoder_hidden_states_index=None,
+        ),
+    )
+
+
 # fmt: off
 def _skip_attention___ret___hidden_states(self, *args, **kwargs):
     hidden_states = kwargs.get("hidden_states", None)
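For context, the two new registry entries describe where a block's forward output places `hidden_states` and `encoder_hidden_states` in its return value. Below is a minimal sketch of the same pattern for a hypothetical custom block (`MyJointStyleBlock` is illustrative and not part of this commit):

```python
import torch

from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry


class MyJointStyleBlock(torch.nn.Module):
    """Hypothetical block that returns (encoder_hidden_states, hidden_states)."""

    def forward(self, hidden_states, encoder_hidden_states):
        return encoder_hidden_states, hidden_states


# The metadata indices tell the hooks which element of the return tuple is which:
# hidden_states is element 1, encoder_hidden_states is element 0 (as for JointTransformerBlock).
TransformerBlockRegistry.register(
    model_class=MyJointStyleBlock,
    metadata=TransformerBlockMetadata(
        return_hidden_states_index=1,
        return_encoder_hidden_states_index=0,
    ),
)
```

A block that returns only `hidden_states`, like `Kandinsky5TransformerDecoderBlock` above, would instead use `return_hidden_states_index=0` and `return_encoder_hidden_states_index=None`.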

src/diffusers/hooks/mag_cache.py

Lines changed: 30 additions & 13 deletions
@@ -115,7 +115,7 @@ class MagCacheConfig:
     calibrate: bool = False
 
     def __post_init__(self):
-        # Strict validation: User MUST provide ratios OR enable calibration.
+        # User MUST provide ratios OR enable calibration.
         if self.mag_ratios is None and not self.calibrate:
             raise ValueError(
                 " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
@@ -151,7 +151,7 @@ def __init__(self) -> None:
 
         # Current step counter (timestep index)
         self.step_index: int = 0
-        
+
         # Calibration storage
         self.calibration_ratios: List[float] = []
 
@@ -179,6 +179,9 @@ def initialize_hook(self, module):
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        if self.state_manager._current_context is None:
+            self.state_manager.set_context("inference")
+
         # Capture input hidden_states
         hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
 
@@ -225,6 +228,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             output = hidden_states
             res = state.previous_residual
 
+            if res.device != output.device:
+                res = res.to(output.device)
+
             # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
             if res.shape == output.shape:
                 output = output + res
@@ -320,7 +326,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             out_hidden = output
 
             in_hidden = state.head_block_input
-            
+
             # Determine residual
             if out_hidden.shape == in_hidden.shape:
                 residual = out_hidden - in_hidden
@@ -345,28 +351,28 @@
     def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
         if state.previous_residual is None:
             # First step has no previous residual to compare against.
-            # We log 1.0 as a neutral starting point.
+            # log 1.0 as a neutral starting point.
             ratio = 1.0
         else:
             # MagCache Calibration Formula: mean(norm(curr) / norm(prev))
             # norm(dim=-1) gives magnitude of each token vector
             curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
             prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
-            
+
             # Avoid division by zero
             ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
-        
+
         state.calibration_ratios.append(ratio)
-    
+
     def _advance_step(self, state: MagCacheState):
         state.step_index += 1
         if state.step_index >= self.config.num_inference_steps:
             # End of inference loop
             if self.config.calibrate:
-                print(f"\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
+                print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
                 print(f"{state.calibration_ratios}\n")
                 logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
-            
+
             # Reset state
             state.step_index = 0
             state.accumulated_ratio = 1.0
@@ -386,6 +392,9 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
         config (`MagCacheConfig`):
             The configuration for MagCache.
     """
+    # Initialize registry on the root module so the Pipeline can set context.
+    HookRegistry.check_if_exists_or_initialize(module)
+
     state_manager = StateManager(MagCacheState, (), {})
     remaining_blocks = []
 
@@ -399,13 +408,11 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
         logger.warning("MagCache: No transformer blocks found to apply hooks.")
         return
 
+    # Handle single-block models
     if len(remaining_blocks) == 1:
-        # Single block case: It acts as both Head (Decision) and Tail (Residual Calc)
         name, block = remaining_blocks[0]
         logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
-        # Apply BlockHook (Tail) FIRST so it is the INNER wrapper
         _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
-        # Apply HeadHook SECOND so it is the OUTER wrapper (controls flow)
         _apply_mag_cache_head_hook(block, state_manager, config)
         return
 
@@ -426,6 +433,11 @@ def _apply_mag_cache_head_hook(
     block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
+
+    # Automatically remove existing hook to allow re-application (e.g. switching modes)
+    if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
+        registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
+
     hook = MagCacheHeadHook(state_manager, config)
     registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
 
@@ -437,5 +449,10 @@ def _apply_mag_cache_block_hook(
     is_tail: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
+
+    # Automatically remove existing hook to allow re-application
+    if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
+        registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
+
     hook = MagCacheBlockHook(state_manager, is_tail, config)
-    registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
+    registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
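The calibration hunk above reduces to one formula: the per-step magnitude ratio is the mean over tokens of norm(current_residual) / norm(previous_residual), with a small epsilon guarding against division by zero. A standalone sketch of that computation follows; the `magnitude_ratio` helper and the `(batch, tokens, dim)` shapes are assumptions for illustration, not code from the commit:

```python
import torch


def magnitude_ratio(current_residual: torch.Tensor, previous_residual: torch.Tensor) -> float:
    # Per-token magnitudes: L2 norm over the feature dimension.
    curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
    prev_norm = torch.linalg.norm(previous_residual.float(), dim=-1)
    # Mean ratio across batch and tokens; epsilon avoids division by zero.
    return (curr_norm / (prev_norm + 1e-8)).mean().item()


# Example with assumed (batch, tokens, dim) shapes.
curr = torch.randn(2, 16, 64)
prev = torch.randn(2, 16, 64)
print(magnitude_ratio(curr, prev))  # one scalar per step, appended to calibration_ratios
```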

tests/hooks/test_mag_cache.py

Lines changed: 20 additions & 19 deletions
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import unittest
-import torch
+
 import numpy as np
+import torch
 
 from diffusers import MagCacheConfig, apply_mag_cache
 from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
@@ -46,7 +47,7 @@ def forward(self, hidden_states, encoder_hidden_states=None):
 class TupleOutputBlock(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        
+
     def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
         # Returns a tuple
         return hidden_states * 2.0, encoder_hidden_states
@@ -88,7 +89,7 @@ def _set_context(self, model, context_name):
         for module in model.modules():
             if hasattr(module, "_diffusers_hook"):
                 module._diffusers_hook._set_context(context_name)
-    
+
     def _get_calibration_data(self, model):
         for module in model.modules():
             if hasattr(module, "_diffusers_hook"):
@@ -143,25 +144,25 @@ def test_mag_cache_retention(self):
         """Test that retention_ratio prevents skipping even if error is low."""
         model = DummyTransformer()
         # Ratios that imply 0 error, so it *would* skip if retention allowed it
-        ratios = np.array([1.0, 1.0]) 
-        
+        ratios = np.array([1.0, 1.0])
+
         config = MagCacheConfig(
             threshold=100.0,
             num_inference_steps=2,
             retention_ratio=1.0,  # Force retention for ALL steps
             mag_ratios=ratios
         )
-        
+
         apply_mag_cache(model, config)
         self._set_context(model, "test_context")
-        
+
         # Step 0
         model(torch.tensor([[[10.0]]]))
-        
+
         # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
         input_t1 = torch.tensor([[[11.0]]])
         output_t1 = model(input_t1)
-        
+
         self.assertTrue(
             torch.allclose(output_t1, torch.tensor([[[44.0]]])),
             f"Expected Compute (44.0) due to retention, got {output_t1.item()}"
@@ -171,29 +172,29 @@ def test_mag_cache_tuple_outputs(self):
         """Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
         model = TupleTransformer()
         ratios = np.array([1.0, 1.0])
-        
+
         config = MagCacheConfig(
             threshold=100.0,
             num_inference_steps=2,
             retention_ratio=0.0,
             mag_ratios=ratios
         )
-        
+
         apply_mag_cache(model, config)
         self._set_context(model, "test_context")
-        
+
         # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
         # Residual = 10.0
         input_t0 = torch.tensor([[[10.0]]])
         enc_t0 = torch.tensor([[[1.0]]])
         out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
         self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
-        
+
         # Step 1: Skip. Input 11.0.
         # Skipped Output = 11 + 10 = 21.0
         input_t1 = torch.tensor([[[11.0]]])
         out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
-        
+
         self.assertTrue(
             torch.allclose(out_1, torch.tensor([[[21.0]]])),
             f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
@@ -203,8 +204,8 @@ def test_mag_cache_reset(self):
         """Test that state resets correctly after num_inference_steps."""
         model = DummyTransformer()
         config = MagCacheConfig(
-            threshold=100.0, 
-            num_inference_steps=2, 
+            threshold=100.0,
+            num_inference_steps=2,
             retention_ratio=0.0,
             mag_ratios=np.array([1.0, 1.0])
         )
@@ -237,7 +238,7 @@ def test_mag_cache_calibration(self):
         # HeadInput = 10. Output = 40. Residual = 30.
         # Ratio 0 is placeholder 1.0
         model(torch.tensor([[[10.0]]]))
-        
+
         # Check intermediate state
         ratios = self._get_calibration_data(model)
         self.assertEqual(len(ratios), 1)
@@ -248,10 +249,10 @@ def test_mag_cache_calibration(self):
         # PrevResidual = 30. CurrResidual = 30.
         # Ratio = 30/30 = 1.0
         model(torch.tensor([[[10.0]]]))
-        
+
         # Verify it computes fully (no skip)
         # If it skipped, output would be 41.0. It should be 40.0
         # Actually in test setup, input is same (10.0) so output 40.0.
         # Let's ensure list is empty after reset (end of step 1)
         ratios_after = self._get_calibration_data(model)
-        self.assertEqual(ratios_after, [])
+        self.assertEqual(ratios_after, [])
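Taken together, the three files support a calibrate-then-infer workflow. The following is a hedged usage sketch built from the tests above; `DummyTransformer` is the toy model from this test file standing in for any transformer with registered blocks, and the config values marked as assumed are illustrative rather than recommended defaults:

```python
import numpy as np
import torch

from diffusers import MagCacheConfig, apply_mag_cache

model = DummyTransformer()

# Pass 1: calibration. No mag_ratios are given, so calibrate=True is required;
# the collected ratios are printed after the final step and the state resets.
apply_mag_cache(model, MagCacheConfig(num_inference_steps=2, calibrate=True))
for _ in range(2):
    model(torch.tensor([[[10.0]]]))

# Pass 2: inference with the calibrated ratios. The new hook-removal logic lets
# apply_mag_cache be called again on the same model to switch modes, and the
# context now defaults to "inference" if the pipeline never sets one.
apply_mag_cache(
    model,
    MagCacheConfig(
        threshold=0.06,                    # assumed value; model-dependent
        num_inference_steps=2,
        retention_ratio=0.2,               # assumed value
        mag_ratios=np.array([1.0, 1.0]),   # paste the printed calibration output here
    ),
)
for _ in range(2):
    output = model(torch.tensor([[[10.0]]]))
```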
