2525
2626logger = logging .get_logger (__name__ )
2727
28+
2829class DummyBlock (torch .nn .Module ):
2930 def __init__ (self ):
3031 super ().__init__ ()
@@ -34,6 +35,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
3435 # This ensures Residual = 2*Input - Input = Input
3536 return hidden_states * 2.0
3637
38+
3739class DummyTransformer (ModelMixin ):
3840 def __init__ (self ):
3941 super ().__init__ ()
@@ -44,6 +46,7 @@ def forward(self, hidden_states, encoder_hidden_states=None):
4446 hidden_states = block (hidden_states , encoder_hidden_states = encoder_hidden_states )
4547 return hidden_states
4648
49+
4750class TupleOutputBlock (torch .nn .Module ):
4851 def __init__ (self ):
4952 super ().__init__ ()
@@ -52,6 +55,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
5255 # Returns a tuple
5356 return hidden_states * 2.0 , encoder_hidden_states
5457
58+
5559class TupleTransformer (ModelMixin ):
5660 def __init__ (self ):
5761 super ().__init__ ()
@@ -65,23 +69,18 @@ def forward(self, hidden_states, encoder_hidden_states=None):
6569 encoder_hidden_states = output [1 ]
6670 return hidden_states , encoder_hidden_states
6771
72+
6873class MagCacheTests (unittest .TestCase ):
6974 def setUp (self ):
7075 # Register standard dummy block
7176 TransformerBlockRegistry .register (
7277 DummyBlock ,
73- TransformerBlockMetadata (
74- return_hidden_states_index = None ,
75- return_encoder_hidden_states_index = None
76- )
78+ TransformerBlockMetadata (return_hidden_states_index = None , return_encoder_hidden_states_index = None ),
7779 )
7880 # Register tuple block (Flux style)
7981 TransformerBlockRegistry .register (
8082 TupleOutputBlock ,
81- TransformerBlockMetadata (
82- return_hidden_states_index = 0 ,
83- return_encoder_hidden_states_index = 1
84- )
83+ TransformerBlockMetadata (return_hidden_states_index = 0 , return_encoder_hidden_states_index = 1 ),
8584 )
8685
8786 def _set_context (self , model , context_name ):
@@ -115,9 +114,9 @@ def test_mag_cache_skipping_logic(self):
115114 config = MagCacheConfig (
116115 threshold = 100.0 ,
117116 num_inference_steps = 2 ,
118- retention_ratio = 0.0 , # Enable immediate skipping
117+ retention_ratio = 0.0 , # Enable immediate skipping
119118 max_skip_steps = 5 ,
120- mag_ratios = ratios
119+ mag_ratios = ratios ,
121120 )
122121
123122 apply_mag_cache (model , config )
@@ -136,8 +135,7 @@ def test_mag_cache_skipping_logic(self):
136135 output_t1 = model (input_t1 )
137136
138137 self .assertTrue (
139- torch .allclose (output_t1 , torch .tensor ([[[41.0 ]]])),
140- f"Expected Skip (41.0), got { output_t1 .item ()} "
138+ torch .allclose (output_t1 , torch .tensor ([[[41.0 ]]])), f"Expected Skip (41.0), got { output_t1 .item ()} "
141139 )
142140
143141 def test_mag_cache_retention (self ):
@@ -149,8 +147,8 @@ def test_mag_cache_retention(self):
149147 config = MagCacheConfig (
150148 threshold = 100.0 ,
151149 num_inference_steps = 2 ,
152- retention_ratio = 1.0 , # Force retention for ALL steps
153- mag_ratios = ratios
150+ retention_ratio = 1.0 , # Force retention for ALL steps
151+ mag_ratios = ratios ,
154152 )
155153
156154 apply_mag_cache (model , config )
@@ -165,20 +163,15 @@ def test_mag_cache_retention(self):
165163
166164 self .assertTrue (
167165 torch .allclose (output_t1 , torch .tensor ([[[44.0 ]]])),
168- f"Expected Compute (44.0) due to retention, got { output_t1 .item ()} "
166+ f"Expected Compute (44.0) due to retention, got { output_t1 .item ()} " ,
169167 )
170168
171169 def test_mag_cache_tuple_outputs (self ):
172170 """Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
173171 model = TupleTransformer ()
174172 ratios = np .array ([1.0 , 1.0 ])
175173
176- config = MagCacheConfig (
177- threshold = 100.0 ,
178- num_inference_steps = 2 ,
179- retention_ratio = 0.0 ,
180- mag_ratios = ratios
181- )
174+ config = MagCacheConfig (threshold = 100.0 , num_inference_steps = 2 , retention_ratio = 0.0 , mag_ratios = ratios )
182175
183176 apply_mag_cache (model , config )
184177 self ._set_context (model , "test_context" )
@@ -196,36 +189,29 @@ def test_mag_cache_tuple_outputs(self):
196189 out_1 , _ = model (input_t1 , encoder_hidden_states = enc_t0 )
197190
198191 self .assertTrue (
199- torch .allclose (out_1 , torch .tensor ([[[21.0 ]]])),
200- f"Tuple skip failed. Expected 21.0, got { out_1 .item ()} "
192+ torch .allclose (out_1 , torch .tensor ([[[21.0 ]]])), f"Tuple skip failed. Expected 21.0, got { out_1 .item ()} "
201193 )
202194
203195 def test_mag_cache_reset (self ):
204196 """Test that state resets correctly after num_inference_steps."""
205197 model = DummyTransformer ()
206198 config = MagCacheConfig (
207- threshold = 100.0 ,
208- num_inference_steps = 2 ,
209- retention_ratio = 0.0 ,
210- mag_ratios = np .array ([1.0 , 1.0 ])
199+ threshold = 100.0 , num_inference_steps = 2 , retention_ratio = 0.0 , mag_ratios = np .array ([1.0 , 1.0 ])
211200 )
212201 apply_mag_cache (model , config )
213202 self ._set_context (model , "test_context" )
214203
215204 input_t = torch .ones (1 , 1 , 1 )
216205
217- model (input_t ) # Step 0
218- model (input_t ) # Step 1 (Skipped)
206+ model (input_t ) # Step 0
207+ model (input_t ) # Step 1 (Skipped)
219208
220209 # Step 2 (Reset -> Step 0) -> Should Compute
221210 # Input 2.0 -> Output 8.0
222211 input_t2 = torch .tensor ([[[2.0 ]]])
223212 output_t2 = model (input_t2 )
224213
225- self .assertTrue (
226- torch .allclose (output_t2 , torch .tensor ([[[8.0 ]]])),
227- "State did not reset correctly"
228- )
214+ self .assertTrue (torch .allclose (output_t2 , torch .tensor ([[[8.0 ]]])), "State did not reset correctly" )
229215
230216 def test_mag_cache_calibration (self ):
231217 """Test that calibration mode records ratios."""
0 commit comments