logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# TODO(aryan): handle mochi attention
_ATTENTION_CLASSES = (Attention,)

_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
@@ -63,73 +64,75 @@ class FasterCacheConfig:
            states again.
        spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
            The timestep range within which the spatial attention computation can be skipped without a significant loss
-            in quality. This is to be determined by the user based on the underlying model. The first value in the tuple
-            is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are
-            in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). For the
-            default values, this would mean that the spatial attention computation skipping will be applicable only
-            after denoising timestep 681 is reached, and continue until the end of the denoising process.
+            in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
+            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
+            timestep 0). For the default values, this would mean that the spatial attention computation skipping will
+            be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
+            process.
        temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
-            The timestep range within which the temporal attention computation can be skipped without a significant loss
-            in quality. This is to be determined by the user based on the underlying model. The first value in the tuple
-            is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are
-            in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0).
+            The timestep range within which the temporal attention computation can be skipped without a significant
+            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
+            denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
+            timestep 0).
        low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
-            The timestep range within which the low frequency weight scaling update is applied. The first value in the tuple
-            is the lower bound and the second value is the upper bound of the timestep range. The callback function for
-            the update is called only within this range.
+            The timestep range within which the low frequency weight scaling update is applied. The first value in the
+            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+            function for the update is called only within this range.
        high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
-            The timestep range within which the high frequency weight scaling update is applied. The first value in the tuple
-            is the lower bound and the second value is the upper bound of the timestep range. The callback function for
-            the update is called only within this range.
+            The timestep range within which the high frequency weight scaling update is applied. The first value in the
+            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+            function for the update is called only within this range.
        alpha_low_frequency (`float`, defaults to `1.1`):
            The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
            the conditional branch outputs.
        alpha_high_frequency (`float`, defaults to `1.1`):
-            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch from
-            the conditional branch outputs.
+            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
+            from the conditional branch outputs.
        unconditional_batch_skip_range (`int`, defaults to `5`):
            Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
            computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
            computing the new unconditional branch states again.
        unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
-            The timestep range within which the unconditional branch computation can be skipped without a significant loss
-            in quality. This is to be determined by the user based on the underlying model. The first value in the tuple
-            is the lower bound and the second value is the upper bound.
+            The timestep range within which the unconditional branch computation can be skipped without a significant
+            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound.
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
-            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any of
-            these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial
-            layer names, or regex patterns. Matching will always be done using a regex match.
+            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
+            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+            partial layer names, or regex patterns. Matching will always be done using a regex match.
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
-            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any of
-            these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial
-            layer names, or regex patterns. Matching will always be done using a regex match.
+            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
+            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+            partial layer names, or regex patterns. Matching will always be done using a regex match.
        attention_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`):
-            The callback function to determine the weight to scale the attention outputs by. This function should take the
-            attention module as input and return a float value. This is used to approximate the unconditional branch from
-            the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as
-            described in the paper, this weight should gradually increase from 0 to 1 as the inference progresses. Users
-            are encouraged to experiment and provide custom weight schedules that take into account the number of inference
-            steps and underlying model behaviour as denoising progresses.
+            The callback function to determine the weight to scale the attention outputs by. This function should take
+            the attention module as input and return a float value. This is used to approximate the unconditional
+            branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
+            Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
+            progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
+            the number of inference steps and underlying model behaviour as denoising progresses.
        low_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`):
-            The callback function to determine the weight to scale the low frequency updates by. If not provided, the default
-            weight is 1.1 for timesteps within the range specified (as described in the paper).
+            The callback function to determine the weight to scale the low frequency updates by. If not provided, the
+            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        high_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`):
-            The callback function to determine the weight to scale the high frequency updates by. If not provided, the default
-            weight is 1.1 for timesteps within the range specified (as described in the paper).
+            The callback function to determine the weight to scale the high frequency updates by. If not provided, the
+            default weight is 1.1 for timesteps within the range specified (as described in the paper).
        tensor_format (`str`, defaults to `"BCFHW"`):
-            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is used to
-            split individual latent frames in order for low and high frequency components to be computed.
+            The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
+            used to split individual latent frames in order for low and high frequency components to be computed.
        _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
-            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional
-            inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will split the inputs
-            into unconditional and conditional branches. This must be a list of exact input kwargs names that contain the
-            batchwise-concatenated unconditional and conditional inputs.
+            The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
+            conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
+            split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
+            names that contain the batchwise-concatenated unconditional and conditional inputs.
        _guidance_distillation_kwargs_identifiers (`List[str]`, defaults to `("guidance",)`):
-            The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the input
-            kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional and conditional
-            branches (unconditional branches are only computed sometimes based on certain checks). This allows usage of
-            FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only attention skipping
-            related parts are applied, and not unconditional branch approximation).
+            The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the
+            input kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional
+            and conditional branches (unconditional branches are only computed sometimes based on certain checks). This
+            allows usage of FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only
+            attention skipping related parts are applied, and not unconditional branch approximation).
    """

    # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
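For quick reference while reading the reflowed docstring above, here is a minimal sketch of a config built from the documented fields and defaults. The field names and values are copied from the docstring; the import path is an assumption and should be checked against this revision.

```python
# Minimal sketch, not taken from this diff: field names/defaults come from the docstring
# above; `from diffusers import FasterCacheConfig` is an assumed export path.
from diffusers import FasterCacheConfig

config = FasterCacheConfig(
    spatial_attention_timestep_skip_range=(-1, 681),
    temporal_attention_timestep_skip_range=None,
    low_frequency_weight_update_timestep_range=(99, 901),
    high_frequency_weight_update_timestep_range=(-1, 301),
    alpha_low_frequency=1.1,
    alpha_high_frequency=1.1,
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 641),
    spatial_attention_block_identifiers=("transformer_blocks.*attn1",),
    tensor_format="BCFHW",
    # Documented fallback is a constant 0.5; the paper suggests ramping this toward 1.0
    # as denoising progresses via a custom callback that receives the attention module.
    attention_weight_callback=lambda module: 0.5,
)
```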
@@ -225,7 +228,6 @@ def reset(self):
def apply_faster_cache(
    pipeline: DiffusionPipeline,
    config: Optional[FasterCacheConfig] = None,
-    denoiser: Optional[nn.Module] = None,
) -> None:
    r"""
    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
@@ -238,9 +240,6 @@ def apply_faster_cache(
            The diffusion pipeline to apply FasterCache to.
        config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
            The configuration to use for FasterCache.
-        denoiser (`Optional[nn.Module]`, `optional`, defaults to `None`):
-            The denoiser module to apply FasterCache to. If `None`, the pipeline's transformer or unet module will be
-            used.

    Example:
    ```python
@@ -310,8 +309,7 @@ def high_frequency_weight_callback(module: nn.Module) -> float:
    if config.tensor_format not in supported_tensor_formats:
        raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")

-    if denoiser is None:
-        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+    denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
    _apply_fastercache_on_denoiser(pipeline, denoiser, config)

    for name, module in denoiser.named_modules():
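With the `denoiser` argument removed, the call site shrinks to the pipeline and config; the hunk above resolves the hook target from `pipeline.transformer`, falling back to `pipeline.unet`. A hedged usage sketch follows (the pipeline class, checkpoint id, and import path are illustrative assumptions, not part of this diff):

```python
# Usage sketch under assumptions: CogVideoXPipeline and the checkpoint id are placeholders,
# and apply_faster_cache/FasterCacheConfig are assumed to be importable from diffusers.
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_timestep_skip_range=(-1, 681),
    unconditional_batch_timestep_skip_range=(-1, 641),
    tensor_format="BCFHW",
)

# No denoiser argument anymore: the transformer (or unet) is picked up from the pipeline.
apply_faster_cache(pipe, config)

video = pipe(prompt="A cat playing the piano", num_inference_steps=50).frames[0]
```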
@@ -344,7 +342,11 @@ def uncond_skip_callback(module: nn.Module) -> bool:
    denoiser._fastercache_state = FasterCacheDenoiserState(
        config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback
    )
-    hook = FasterCacheDenoiserHook(config._unconditional_conditional_input_kwargs_identifiers, config._guidance_distillation_kwargs_identifiers, config.tensor_format)
+    hook = FasterCacheDenoiserHook(
+        config._unconditional_conditional_input_kwargs_identifiers,
+        config._guidance_distillation_kwargs_identifiers,
+        config.tensor_format,
+    )
    add_hook_to_module(denoiser, hook, append=True)


@@ -408,11 +410,12 @@ def skip_callback(module: nn.Module) -> bool:
class FasterCacheDenoiserHook(ModelHook):
    _is_stateful = True

-    def __init__(self,
-            uncond_cond_input_kwargs_identifiers: List[str],
-            guidance_distillation_kwargs_identifiers: List[str],
-            tensor_format: str
-        ) -> None:
+    def __init__(
+        self,
+        uncond_cond_input_kwargs_identifiers: List[str],
+        guidance_distillation_kwargs_identifiers: List[str],
+        tensor_format: str,
+    ) -> None:
        super().__init__()

        # We can't easily detect what args are to be split in unconditional and conditional branches. We
@@ -451,7 +454,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
            # Make all children FasterCacheBlockHooks aware of whether the model is guidance distilled or not
            # because we cannot determine this within the block hooks
            for name, child_module in module.named_modules():
-                if hasattr(child_module, "_fastercache_state") and isinstance(child_module._fastercache_state, FasterCacheBlockState):
+                if hasattr(child_module, "_fastercache_state") and isinstance(
+                    child_module._fastercache_state, FasterCacheBlockState
+                ):
                    # TODO(aryan): remove later
                    logger.debug(f"Setting guidance distillation flag for layer: {name}")
                    child_module._fastercache_state.is_guidance_distilled = state.is_guidance_distilled
@@ -570,7 +575,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
        # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
        # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
-        should_skip_attention = state.skip_callback(module) and (state.is_guidance_distilled or state.batch_size != batch_size)
+        should_skip_attention = state.skip_callback(module) and (
+            state.is_guidance_distilled or state.batch_size != batch_size
+        )

        if should_skip_attention:
            # TODO(aryan): remove later
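Pulled out of the hook for readability, the reformatted condition above reads as follows. The names below are hypothetical standalone parameters; the real code uses `state` and `module` inside `new_forward` as shown in this hunk.

```python
# Hedged restatement of the skip condition above, with hypothetical standalone names.
def should_skip_attention(
    skip_now: bool,              # result of state.skip_callback(module)
    is_guidance_distilled: bool,
    cached_batch_size: int,      # state.batch_size: the full unconditional+conditional batch
    current_batch_size: int,     # batch size of the hidden states seen on this call
) -> bool:
    # If the incoming batch still has the full uncond+cond size, the inputs were never split,
    # so the conditional-only cache cannot stand in for them, unless the model is guidance
    # distilled and never runs a separate unconditional branch at all.
    return skip_now and (is_guidance_distilled or cached_batch_size != current_batch_size)
```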