@@ -115,7 +115,7 @@ class MagCacheConfig:
115115 calibrate : bool = False
116116
117117 def __post_init__ (self ):
118- # Strict validation: User MUST provide ratios OR enable calibration.
118+ # User MUST provide ratios OR enable calibration.
119119 if self .mag_ratios is None and not self .calibrate :
120120 raise ValueError (
121121 " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n "
@@ -151,7 +151,7 @@ def __init__(self) -> None:
151151
152152 # Current step counter (timestep index)
153153 self .step_index : int = 0
154-
154+
155155 # Calibration storage
156156 self .calibration_ratios : List [float ] = []
157157
@@ -179,6 +179,9 @@ def initialize_hook(self, module):
179179 return module
180180
181181 def new_forward (self , module : torch .nn .Module , * args , ** kwargs ):
182+ if self .state_manager ._current_context is None :
183+ self .state_manager .set_context ("inference" )
184+
182185 # Capture input hidden_states
183186 hidden_states = self ._metadata ._get_parameter_from_args_kwargs ("hidden_states" , args , kwargs )
184187
@@ -225,6 +228,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
225228 output = hidden_states
226229 res = state .previous_residual
227230
231+ if res .device != output .device :
232+ res = res .to (output .device )
233+
228234 # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
229235 if res .shape == output .shape :
230236 output = output + res
@@ -320,7 +326,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
320326 out_hidden = output
321327
322328 in_hidden = state .head_block_input
323-
329+
324330 # Determine residual
325331 if out_hidden .shape == in_hidden .shape :
326332 residual = out_hidden - in_hidden
@@ -345,28 +351,28 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
    """Append the magnitude ratio of the current vs. previous residual to the calibration log.

    The MagCache calibration formula is ``mean(norm(curr) / norm(prev))``,
    where the norm over ``dim=-1`` gives the magnitude of each token vector.
    """
    if state.previous_residual is None:
        # The very first step has no previous residual to compare against;
        # log 1.0 as a neutral starting point.
        state.calibration_ratios.append(1.0)
        return

    # Norms are computed in float32 for numerical stability.
    numerator = torch.linalg.norm(current_residual.float(), dim=-1)
    # Small epsilon in the denominator guards against division by zero.
    denominator = torch.linalg.norm(state.previous_residual.float(), dim=-1) + 1e-8

    state.calibration_ratios.append((numerator / denominator).mean().item())
361367 def _advance_step (self , state : MagCacheState ):
362368 state .step_index += 1
363369 if state .step_index >= self .config .num_inference_steps :
364370 # End of inference loop
365371 if self .config .calibrate :
366- print (f "\n [MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):" )
372+ print ("\n [MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):" )
367373 print (f"{ state .calibration_ratios } \n " )
368374 logger .info (f"MagCache Calibration Results: { state .calibration_ratios } " )
369-
375+
370376 # Reset state
371377 state .step_index = 0
372378 state .accumulated_ratio = 1.0
@@ -386,6 +392,9 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
386392 config (`MagCacheConfig`):
387393 The configuration for MagCache.
388394 """
395+ # Initialize registry on the root module so the Pipeline can set context.
396+ HookRegistry .check_if_exists_or_initialize (module )
397+
389398 state_manager = StateManager (MagCacheState , (), {})
390399 remaining_blocks = []
391400
@@ -399,13 +408,11 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
399408 logger .warning ("MagCache: No transformer blocks found to apply hooks." )
400409 return
401410
411+ # Handle single-block models
402412 if len (remaining_blocks ) == 1 :
403- # Single block case: It acts as both Head (Decision) and Tail (Residual Calc)
404413 name , block = remaining_blocks [0 ]
405414 logger .info (f"MagCache: Applying Head+Tail Hooks to single block '{ name } '" )
406- # Apply BlockHook (Tail) FIRST so it is the INNER wrapper
407415 _apply_mag_cache_block_hook (block , state_manager , config , is_tail = True )
408- # Apply HeadHook SECOND so it is the OUTER wrapper (controls flow)
409416 _apply_mag_cache_head_hook (block , state_manager , config )
410417 return
411418
def _apply_mag_cache_head_hook(
    block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig
) -> None:
    """Register the MagCache head (decision) hook on ``block``'s hook registry."""
    registry = HookRegistry.check_if_exists_or_initialize(block)

    # Drop any stale head hook first so the hook can be re-applied,
    # e.g. when switching between calibration and inference modes.
    if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
        registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)

    registry.register_hook(MagCacheHeadHook(state_manager, config), _MAG_CACHE_LEADER_BLOCK_HOOK)
@@ -437,5 +449,10 @@ def _apply_mag_cache_block_hook(
437449 is_tail : bool = False ,
438450) -> None :
439451 registry = HookRegistry .check_if_exists_or_initialize (block )
452+
453+ # Automatically remove existing hook to allow re-application
454+ if registry .get_hook (_MAG_CACHE_BLOCK_HOOK ) is not None :
455+ registry .remove_hook (_MAG_CACHE_BLOCK_HOOK )
456+
440457 hook = MagCacheBlockHook (state_manager , is_tail , config )
441- registry .register_hook (hook , _MAG_CACHE_BLOCK_HOOK )
458+ registry .register_hook (hook , _MAG_CACHE_BLOCK_HOOK )
0 commit comments