
Commit ebbebbe

Apply style fixes
Parent: 535a14e

File tree: 5 files changed (+41, -53 lines)


src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -163,12 +163,12 @@
 )
 _import_structure["hooks"].extend(
     [
+        "FLUX_MAG_RATIOS",
         "FasterCacheConfig",
         "FirstBlockCacheConfig",
         "HookRegistry",
         "LayerSkipConfig",
         "MagCacheConfig",
-        "FLUX_MAG_RATIOS",
         "PyramidAttentionBroadcastConfig",
         "SmoothedEnergyGuidanceConfig",
         "TaylorSeerCacheConfig",
@@ -898,9 +898,9 @@
         TangentialClassifierFreeGuidance,
     )
     from .hooks import (
+        FLUX_MAG_RATIOS,
         FasterCacheConfig,
         FirstBlockCacheConfig,
-        FLUX_MAG_RATIOS,
         HookRegistry,
         LayerSkipConfig,
         MagCacheConfig,
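
Note: the move keeps `FLUX_MAG_RATIOS` in sorted position (case-sensitive sorting places the all-caps constant before CamelCase names) in both the lazy `_import_structure` map and the `TYPE_CHECKING` block, so the two export lists stay in sync. After this change the constant is reachable either way; a quick sketch (hypothetical usage, not part of the diff):

    from diffusers import FLUX_MAG_RATIOS        # lazy top-level export
    from diffusers.hooks import FLUX_MAG_RATIOS  # direct submodule import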

src/diffusers/hooks/_common.py

Lines changed: 7 additions & 1 deletion
@@ -23,7 +23,13 @@
 _ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
 _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
 
-_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers", "visual_transformer_blocks",)
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
+    "blocks",
+    "transformer_blocks",
+    "single_transformer_blocks",
+    "layers",
+    "visual_transformer_blocks",
+)
 _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
 _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

src/diffusers/hooks/_helpers.py

Lines changed: 1 addition & 3 deletions
@@ -184,12 +184,12 @@ def _register_transformer_blocks_metadata():
         HunyuanImageSingleTransformerBlock,
         HunyuanImageTransformerBlock,
     )
+    from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
     from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
     from ..models.transformers.transformer_mochi import MochiTransformerBlock
     from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
     from ..models.transformers.transformer_wan import WanTransformerBlock
     from ..models.transformers.transformer_z_image import ZImageTransformerBlock
-    from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
 
     # BasicTransformerBlock
     TransformerBlockRegistry.register(
@@ -332,7 +332,6 @@ def _register_transformer_blocks_metadata():
         ),
     )
 
-
     TransformerBlockRegistry.register(
         model_class=JointTransformerBlock,
         metadata=TransformerBlockMetadata(
@@ -341,7 +340,6 @@ def _register_transformer_blocks_metadata():
         ),
     )
 
-
     # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
     TransformerBlockRegistry.register(
         model_class=Kandinsky5TransformerDecoderBlock,

src/diffusers/hooks/mag_cache.py

Lines changed: 12 additions & 14 deletions
@@ -93,18 +93,18 @@ class MagCacheConfig:
         max_skip_steps (`int`, defaults to `5`):
             The maximum number of consecutive steps that can be skipped (K in the paper).
         retention_ratio (`float`, defaults to `0.1`):
-            The fraction of initial steps during which skipping is disabled to ensure stability.
-            For example, if `num_inference_steps` is 28 and `retention_ratio` is 0.1, the first 3 steps will never be skipped.
+            The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
+            `num_inference_steps` is 28 and `retention_ratio` is 0.1, the first 3 steps will never be skipped.
         num_inference_steps (`int`, defaults to `28`):
             The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
         mag_ratios (`np.ndarray`, *optional*):
-            The pre-computed magnitude ratios for the model. These are checkpoint-dependent.
-            If not provided, you must set `calibrate=True` to calculate them for your specific model.
-            For Flux models, you can use `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
+            The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
+            set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
+            `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
         calibrate (`bool`, defaults to `False`):
-            If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates
-            the magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios`
-            for new models or schedulers.
+            If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
+            magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
+            models or schedulers.
     """
 
     threshold: float = 0.24
@@ -335,10 +335,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
             if diff == 0:
                 residual = out_hidden - in_hidden
             else:
-                residual = out_hidden - in_hidden # Fallback to matching tail
+                residual = out_hidden - in_hidden  # Fallback to matching tail
         else:
-            # Fallback for completely mismatched shapes
-            residual = out_hidden # Invalid but prevents crash
+            # Fallback for completely mismatched shapes
+            residual = out_hidden  # Invalid but prevents crash
 
         if self.config.calibrate:
             self._perform_calibration_step(state, residual)
@@ -429,9 +429,7 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
     _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
 
 
-def _apply_mag_cache_head_hook(
-    block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig
-) -> None:
+def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
 
     # Automatically remove existing hook to allow re-application (e.g. switching modes)
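
The docstring above covers the whole `MagCacheConfig` surface. A minimal usage sketch, assuming a Flux transformer already loaded and bound to `transformer` (the pipeline setup itself is outside this diff):

    from diffusers import MagCacheConfig
    from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS, apply_mag_cache

    # Defaults follow the docstring: threshold=0.24, max_skip_steps=5,
    # retention_ratio=0.1. num_inference_steps must match the pipeline call
    # so mag_ratios can be interpolated onto the actual step schedule.
    config = MagCacheConfig(
        num_inference_steps=28,
        mag_ratios=FLUX_MAG_RATIOS,  # pre-computed, checkpoint-dependent ratios for Flux
    )
    apply_mag_cache(transformer, config)  # `transformer` is an assumed pre-loaded model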

tests/hooks/test_mag_cache.py

Lines changed: 19 additions & 33 deletions
@@ -25,6 +25,7 @@
 
 logger = logging.get_logger(__name__)
 
+
 class DummyBlock(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -34,6 +35,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
         # This ensures Residual = 2*Input - Input = Input
         return hidden_states * 2.0
 
+
 class DummyTransformer(ModelMixin):
     def __init__(self):
         super().__init__()
@@ -44,6 +46,7 @@ def forward(self, hidden_states, encoder_hidden_states=None):
             hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
         return hidden_states
 
+
 class TupleOutputBlock(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -52,6 +55,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
         # Returns a tuple
         return hidden_states * 2.0, encoder_hidden_states
 
+
 class TupleTransformer(ModelMixin):
     def __init__(self):
         super().__init__()
@@ -65,23 +69,18 @@ def forward(self, hidden_states, encoder_hidden_states=None):
             encoder_hidden_states = output[1]
         return hidden_states, encoder_hidden_states
 
+
 class MagCacheTests(unittest.TestCase):
     def setUp(self):
         # Register standard dummy block
         TransformerBlockRegistry.register(
             DummyBlock,
-            TransformerBlockMetadata(
-                return_hidden_states_index=None,
-                return_encoder_hidden_states_index=None
-            )
+            TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
         )
         # Register tuple block (Flux style)
         TransformerBlockRegistry.register(
             TupleOutputBlock,
-            TransformerBlockMetadata(
-                return_hidden_states_index=0,
-                return_encoder_hidden_states_index=1
-            )
+            TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
         )
 
     def _set_context(self, model, context_name):
@@ -115,9 +114,9 @@ def test_mag_cache_skipping_logic(self):
         config = MagCacheConfig(
             threshold=100.0,
             num_inference_steps=2,
-            retention_ratio=0.0, # Enable immediate skipping
+            retention_ratio=0.0,  # Enable immediate skipping
             max_skip_steps=5,
-            mag_ratios=ratios
+            mag_ratios=ratios,
         )
 
         apply_mag_cache(model, config)
@@ -136,8 +135,7 @@ def test_mag_cache_skipping_logic(self):
         output_t1 = model(input_t1)
 
         self.assertTrue(
-            torch.allclose(output_t1, torch.tensor([[[41.0]]])),
-            f"Expected Skip (41.0), got {output_t1.item()}"
+            torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
         )
 
     def test_mag_cache_retention(self):
@@ -149,8 +147,8 @@ def test_mag_cache_retention(self):
         config = MagCacheConfig(
             threshold=100.0,
             num_inference_steps=2,
-            retention_ratio=1.0, # Force retention for ALL steps
-            mag_ratios=ratios
+            retention_ratio=1.0,  # Force retention for ALL steps
+            mag_ratios=ratios,
         )
 
         apply_mag_cache(model, config)
@@ -165,20 +163,15 @@ def test_mag_cache_retention(self):
 
         self.assertTrue(
             torch.allclose(output_t1, torch.tensor([[[44.0]]])),
-            f"Expected Compute (44.0) due to retention, got {output_t1.item()}"
+            f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
         )
 
     def test_mag_cache_tuple_outputs(self):
         """Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
         model = TupleTransformer()
         ratios = np.array([1.0, 1.0])
 
-        config = MagCacheConfig(
-            threshold=100.0,
-            num_inference_steps=2,
-            retention_ratio=0.0,
-            mag_ratios=ratios
-        )
+        config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
 
         apply_mag_cache(model, config)
         self._set_context(model, "test_context")
@@ -196,36 +189,29 @@ def test_mag_cache_tuple_outputs(self):
         out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
 
         self.assertTrue(
-            torch.allclose(out_1, torch.tensor([[[21.0]]])),
-            f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
+            torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
         )
 
     def test_mag_cache_reset(self):
         """Test that state resets correctly after num_inference_steps."""
         model = DummyTransformer()
         config = MagCacheConfig(
-            threshold=100.0,
-            num_inference_steps=2,
-            retention_ratio=0.0,
-            mag_ratios=np.array([1.0, 1.0])
+            threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
         )
         apply_mag_cache(model, config)
         self._set_context(model, "test_context")
 
         input_t = torch.ones(1, 1, 1)
 
-        model(input_t) # Step 0
-        model(input_t) # Step 1 (Skipped)
+        model(input_t)  # Step 0
+        model(input_t)  # Step 1 (Skipped)
 
         # Step 2 (Reset -> Step 0) -> Should Compute
         # Input 2.0 -> Output 8.0
         input_t2 = torch.tensor([[[2.0]]])
         output_t2 = model(input_t2)
 
-        self.assertTrue(
-            torch.allclose(output_t2, torch.tensor([[[8.0]]])),
-            "State did not reset correctly"
-        )
+        self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
 
     def test_mag_cache_calibration(self):
         """Test that calibration mode records ratios."""
