 class FasterCacheConfig:
     r"""
     Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
-    """

-    num_train_timesteps: int = 1000
+    Attributes:"""

     # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
     # after some testing. We default to 2 if these parameters are not provided.
-    spatial_attention_block_skip_range: Optional[int] = None
+    spatial_attention_block_skip_range: int = 2
     temporal_attention_block_skip_range: Optional[int] = None

     # TODO(aryan): write heuristics for the best way to obtain these values
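
For reference, a minimal sketch of constructing the config with these fields; only `FasterCacheConfig` and the two skip-range fields come from the code above, the chosen values are illustrative:

    config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,   # matches the new default above
        temporal_attention_block_skip_range=2,  # defaults to None; set explicitly to also cache temporal attention
    )
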
@@ -145,6 +144,9 @@ def apply_faster_cache(
     r"""
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

+    Note: FasterCache should only be applied when classifier-free guidance is used. Without it, inference may still
+    run successfully, but the results will not be as expected.
+
     Args:
         pipeline (`DiffusionPipeline`):
             The diffusion pipeline to apply FasterCache to.
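
A hedged end-to-end sketch of the call described in this docstring; the CogVideoX checkpoint and generation arguments are assumptions, only `apply_faster_cache` and `FasterCacheConfig` come from this diff:

    import torch
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.bfloat16).to("cuda")
    apply_faster_cache(pipe, FasterCacheConfig())
    # Keep classifier-free guidance enabled (guidance_scale > 1), as required by the note above.
    video = pipe(prompt="a corgi surfing a wave", guidance_scale=6.0).frames[0]
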
@@ -163,15 +165,6 @@ def apply_faster_cache(
     if config is None:
         config = FasterCacheConfig()

-    if config.spatial_attention_block_skip_range is None and config.temporal_attention_block_skip_range is None:
-        logger.warning(
-            "FasterCache requires one of `spatial_attention_block_skip_range` and/or `temporal_attention_block_skip_range` "
-            "to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2` and "
-            "`temporal_attention_block_skip_range=2`. To avoid this warning, please set one of the above parameters."
-        )
-        config.spatial_attention_block_skip_range = 2
-        config.temporal_attention_block_skip_range = 2
-
     if config.attention_weight_callback is None:
         # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
         # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
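
The default set here is a constant 0.5. A sketch of a user-supplied `attention_weight_callback` that instead ramps the weight from 0 to 1 over inference, as the paper recommends (the step bookkeeping is an assumption, not part of this code):

    num_inference_steps = 50
    current_step = {"value": 0}  # hypothetical counter, advanced once per denoising step by the caller

    def ramping_weight(module) -> float:
        # Approximation weight grows linearly from 0 to 1 as inference progresses.
        return current_step["value"] / max(num_inference_steps - 1, 1)

    config = FasterCacheConfig(attention_weight_callback=ramping_weight)
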
@@ -231,12 +224,6 @@ def _apply_fastercache_on_denoiser(
     pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig
 ) -> None:
     def uncond_skip_callback(module: nn.Module) -> bool:
-        # If we are not using classifier-free guidance, we cannot skip the denoiser computation. We only compute the
-        # conditional branch in this case.
-        is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance
-        if not is_using_classifier_free_guidance:
-            return False
-
         # We skip the unconditional branch only if the following conditions are met:
         #   1. We have completed at least one iteration of the denoiser
         #   2. The current timestep is within the range specified by the user. This is the optimal timestep range
@@ -298,20 +285,13 @@ def _apply_fastercache_on_attention_class(
         return

     def skip_callback(module: nn.Module) -> bool:
-        is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance
-        if not is_using_classifier_free_guidance:
-            return False
-
         fastercache_state: FasterCacheState = module._fastercache_state
         is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1]

         if not is_within_timestep_range:
             # We are not yet in the phase of inference where skipping attention is possible with minimal quality
             # loss, as described in the paper. So, the attention computation cannot be skipped.
             return False
-        if fastercache_state.cache is None or fastercache_state.iteration < 2:
-            # We need at least 2 iterations to start skipping attention computation
-            return False

         should_compute_attention = (
             fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0
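
For intuition, a small standalone sketch of the cadence implied by the modulo check above (names illustrative): with `block_skip_range=2`, attention is recomputed on every second iteration inside the skip window and the cached approximation is used otherwise.

    block_skip_range = 2
    for iteration in range(1, 9):
        should_compute_attention = iteration > 0 and iteration % block_skip_range == 0
        print(iteration, "compute attention" if should_compute_attention else "approximate from cache")
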
@@ -358,8 +338,6 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             # TODO(aryan): remove later
             logger.debug("Skipping unconditional branch computation")

-        if should_skip_uncond:
-            breakpoint()
         output = module._old_forward(*args, **kwargs)
         # TODO(aryan): handle Transformer2DModelOutput
         hidden_states = output[0] if isinstance(output, tuple) else output
@@ -422,6 +400,22 @@ def reset_state(self, module: nn.Module) -> None:
 class FasterCacheBlockHook(ModelHook):
     _is_stateful = True

+    def _compute_approximated_attention_output(
+        self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
+    ) -> torch.Tensor:
+        # TODO(aryan): these conditions may not be needed after the latest refactor. They exist for safety; test whether they can be removed.
+        if t_2_output.size(0) != batch_size:
+            # The cached t_2_output contains both the unconditional and conditional branch outputs, batchwise-concatenated.
+            # Just take the conditional branch outputs.
+            assert t_2_output.size(0) == 2 * batch_size
+            t_2_output = t_2_output[batch_size:]
+        if t_output.size(0) != batch_size:
+            # The cached t_output contains both the unconditional and conditional branch outputs, batchwise-concatenated.
+            # Just take the conditional branch outputs.
+            assert t_output.size(0) == 2 * batch_size
+            t_output = t_output[batch_size:]
+        return t_output + (t_output - t_2_output) * weight
+
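
A hedged worked example of the extrapolation performed by `_compute_approximated_attention_output` above (shapes are made up): the cached outputs from steps t-2 and t-1 are linearly extrapolated, and a batchwise-concatenated cache entry is first sliced down to its conditional half.

    import torch

    batch_size = 1
    t_2_output = torch.randn(2 * batch_size, 16, 64)  # uncond + cond concatenated cache entry
    t_output = torch.randn(batch_size, 16, 64)        # conditional-only cache entry
    weight = 0.5

    t_2_cond = t_2_output[batch_size:]                # keep only the conditional half
    approx = t_output + (t_output - t_2_cond) * weight
    assert approx.shape == (batch_size, 16, 64)
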
     def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
         args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
         state: FasterCacheState = module._fastercache_state
@@ -435,40 +429,59 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any:
             state.batch_size = batch_size

         # If we have to skip due to the skip conditions, then let's skip as expected.
-        # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. So,
-        # if state.batch_size (which is the true unconditional-conditional batch size) is same as the current
-        # batch size, we don't perform the layer skip. Otherwise, we conditionally skip the layer based on
-        # what state.skip_callback returns.
-        if state.skip_callback(module) and state.batch_size != batch_size:
+        # But, we can't skip if the denoiser wants to infer both the unconditional and conditional branches. This
+        # is because the expected output shapes of the attention layer will not match if we only return values from
+        # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
+        # unconditional-conditional batch size) is the same as the current batch size, we don't perform the layer
+        # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
+        should_skip_attention = state.skip_callback(module) and state.batch_size != batch_size
+
+        if should_skip_attention:
             # TODO(aryan): remove later
-            logger.debug("Skipping layer computation")
-            t_2_output, t_output = state.cache
-
-            # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed
-            if t_2_output.size(0) != batch_size:
-                # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
-                # take the conditional branch outputs.
-                assert t_2_output.size(0) == 2 * batch_size
-                t_2_output = t_2_output[batch_size:]
-            if t_output.size(0) != batch_size:
-                # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
-                # take the conditional branch outputs.
-                assert t_output.size(0) == 2 * batch_size
-                t_output = t_output[batch_size:]
-
-            output = t_output + (t_output - t_2_output) * state.weight_callback(module)
+            logger.debug("Skipping attention")
+
+            if torch.is_tensor(state.cache[-1]):
+                t_2_output, t_output = state.cache
+                weight = state.weight_callback(module)
+                output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
+            else:
+                # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
+                # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
+                # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
+                # a forward pass of the block. We need to compute the approximated output for each of these tensors.
+                # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
+                # allows us to compute the approximated attention output for each tensor in the cache.
+                output = ()
+                for t_2_output, t_output in zip(*state.cache):
+                    result = self._compute_approximated_attention_output(
+                        t_2_output, t_output, state.weight_callback(module), batch_size
+                    )
+                    output += (result,)
         else:
+            logger.debug("Computing attention")
             output = module._old_forward(*args, **kwargs)

-        # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
-        # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
-        cache_output = output
-        if output.size(0) == state.batch_size:
-            cache_output = cache_output.chunk(2, dim=0)[1]
-
-        # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference
-        # and only-conditional branch inference.
-        assert 2 * cache_output.size(0) == state.batch_size
+        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
+        # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
+        # both cases.
+        if torch.is_tensor(output):
+            cache_output = output
+            if cache_output.size(0) == state.batch_size:
+                # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
+                # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
+                cache_output = cache_output.chunk(2, dim=0)[1]
+
+            # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference
+            # and only-conditional branch inference.
+            assert 2 * cache_output.size(0) == state.batch_size
+        else:
+            # Cache all return values and perform the same operation as above
+            cache_output = ()
+            for out in output:
+                if out.size(0) == state.batch_size:
+                    out = out.chunk(2, dim=0)[1]
+                assert 2 * out.size(0) == state.batch_size
+                cache_output += (out,)

         if state.cache is None:
             state.cache = [cache_output, cache_output]
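
To illustrate the cache layout handled above: `state.cache` keeps the block outputs of the previous two iterations, and since a block may return a tuple, `zip(*state.cache)` pairs each return value's t-2 and t-1 entries. A toy sketch with strings standing in for tensors (the roll-forward update shown at the end is an assumption about code outside this hunk):

    cache = [("A_1", "B_1"), ("A_2", "B_2")]  # [outputs at t-2, outputs at t-1] for a tuple-returning block
    for t_2_output, t_output in zip(*cache):
        print(t_2_output, "->", t_output)     # pairs up as (A_1, A_2), then (B_1, B_2)

    # On later iterations the cache rolls forward, dropping the oldest entry and appending the newest one.
    cache = [cache[-1], ("A_3", "B_3")]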