 _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
 _ATTENTION_CLASSES = (Attention, MochiAttention)
 _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
-    "blocks.*attn",
-    "transformer_blocks.*attn",
-    "single_transformer_blocks.*attn",
+    "^blocks.*attn",
+    "^transformer_blocks.*attn",
+    "^single_transformer_blocks.*attn"
 )
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn",)
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
+_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
 _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
     "hidden_states",
     "encoder_hidden_states",
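These identifier tuples are matched against module names with `re.search`, so the new `^` anchors restrict them to attention layers that sit directly under the denoiser's top-level block lists; unanchored, `"transformer_blocks.*attn"` also matched every `single_transformer_blocks` entry. A minimal sketch of the difference (illustrative only, not part of the diff):

```python
import re

name = "single_transformer_blocks.0.attn"
print(bool(re.search("transformer_blocks.*attn", name)))   # True: the unanchored pattern also hits single_* blocks
print(bool(re.search("^transformer_blocks.*attn", name)))  # False: the anchored pattern requires the name to start with "transformer_blocks"
```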
@@ -276,9 +277,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
             self.state.iteration > 0
             and is_within_timestep_range
             and self.state.iteration % self.unconditional_batch_skip_range != 0
+            and not self.is_guidance_distilled
         )
 
-        if should_skip_uncond and not self.is_guidance_distilled:
+        if should_skip_uncond:
             is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
             if is_any_kwarg_uncond:
                 logger.debug("FasterCache - Skipping unconditional branch computation")
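Folding `not self.is_guidance_distilled` into the condition itself (instead of checking it at the `if`) disables unconditional-branch skipping entirely for guidance-distilled models, which run a single conditional forward pass and have no unconditional batch to approximate. A standalone restatement of the decision, as a sketch rather than code from this file:

```python
def should_skip_unconditional(
    iteration: int,
    is_within_timestep_range: bool,
    unconditional_batch_skip_range: int,
    is_guidance_distilled: bool,
) -> bool:
    # Never skip on the first iteration, only skip inside the configured timestep
    # window, skip iterations that are not multiples of the skip range, and never
    # skip when the model is guidance-distilled.
    return (
        iteration > 0
        and is_within_timestep_range
        and iteration % unconditional_batch_skip_range != 0
        and not is_guidance_distilled
    )
```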
@@ -483,7 +485,7 @@ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
 
 def apply_faster_cache(
     module: torch.nn.Module,
-    config: Optional[FasterCacheConfig] = None,
+    config: FasterCacheConfig
 ) -> None:
     r"""
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
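With the `Optional[...] = None` default removed, callers must now construct a `FasterCacheConfig` explicitly instead of getting a silently-applied default. A hedged usage sketch, assuming the top-level `diffusers` exports and a CogVideoX pipeline; the field values are illustrative:

```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    attention_weight_callback=lambda _: 0.5,  # constant weight, matching the fallback installed below
)
apply_faster_cache(pipe.transformer, config)  # omitting `config` is now a TypeError
```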
@@ -515,10 +517,6 @@ def apply_faster_cache(
     ```
     """
 
-    if config is None:
-        logger.warning("No FasterCacheConfig provided. Using default configuration.")
-        config = FasterCacheConfig()
-
     if config.attention_weight_callback is None:
         # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
         # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
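The comment above notes that the paper recommends a weight that ramps from 0 to 1 over the denoising schedule rather than the constant 0.5 fallback. A hedged sketch of such a callback; `step_tracker` is an assumed external counter that the caller bumps once per denoising step (e.g. from a pipeline step callback), and `config` is an assumed `FasterCacheConfig` instance:

```python
import torch

def make_ramp_weight_callback(num_inference_steps: int, step_tracker: dict):
    def attention_weight_callback(module: torch.nn.Module) -> float:
        # Linearly increase the attention weight from 0 to 1 as inference progresses,
        # as suggested in the FasterCache paper.
        return min(step_tracker["step"] / max(num_inference_steps - 1, 1), 1.0)

    return attention_weight_callback

step_tracker = {"step": 0}
config.attention_weight_callback = make_ramp_weight_callback(50, step_tracker)
```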
@@ -568,7 +566,8 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     for name, submodule in module.named_modules():
         if not isinstance(submodule, _ATTENTION_CLASSES):
             continue
-        _apply_faster_cache_on_attention_class(name, submodule, config)
+        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
+            _apply_faster_cache_on_attention_class(name, submodule, config)
 
 
 def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
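Because hook registration is now gated on the module name matching one of the anchored identifiers, attention layers outside the recognized block lists no longer receive the block hook. A quick, hedged way to inspect what would be hooked; `transformer` is an assumed denoiser module and the other names come from this file:

```python
import re

hooked = [
    name
    for name, submodule in transformer.named_modules()
    if isinstance(submodule, _ATTENTION_CLASSES)
    and any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS)
]
print(hooked)  # e.g. ["transformer_blocks.0.attn1", ...]
```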
@@ -590,13 +589,10 @@ def _apply_faster_cache_on_attention_class(name: str, module: Attention, config:
     is_spatial_self_attention = (
         any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
         and config.spatial_attention_block_skip_range is not None
-        and not module.is_cross_attention
+        and not getattr(module, "is_cross_attention", False)
     )
     is_temporal_self_attention = (
-        any(
-            f"{identifier}." in name or identifier == name
-            for identifier in config.temporal_attention_block_identifiers
-        )
+        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
         and config.temporal_attention_block_skip_range is not None
         and not module.is_cross_attention
     )
@@ -633,7 +629,7 @@ def _apply_faster_cache_on_attention_class(name: str, module: Attention, config:
     registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
 
 
-# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39
+# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
 @torch.no_grad()
 def _split_low_high_freq(x):
     fft = torch.fft.fft2(x)
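The body of `_split_low_high_freq` is cut off by the hunk above; for orientation, the referenced FasterCache script splits the spectrum by shifting the FFT and masking a centered low-frequency region. A hedged, self-contained sketch of that idea, not necessarily the exact continuation of this function:

```python
import torch

@torch.no_grad()
def split_low_high_freq_sketch(x: torch.Tensor, radius_ratio: float = 0.2):
    # x has shape (..., H, W); returns (low_freq_fft, high_freq_fft) in the shifted frequency domain.
    fft = torch.fft.fftshift(torch.fft.fft2(x))
    height, width = x.shape[-2:]
    cy, cx = height // 2, width // 2
    radius = int(min(height, width) * radius_ratio)

    ys = torch.arange(height, device=x.device).view(-1, 1)
    xs = torch.arange(width, device=x.device).view(1, -1)
    low_mask = (ys - cy) ** 2 + (xs - cx) ** 2 <= radius**2  # centered disk = low frequencies

    low_freq_fft = fft * low_mask
    high_freq_fft = fft * ~low_mask
    return low_freq_fft, high_freq_fft
```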