@@ -141,20 +141,20 @@ class FasterCacheConfig:
     temporal_attention_block_skip_range: Optional[int] = None

     # TODO(aryan): write heuristics for what the best way to obtain these values are
-    spatial_attention_timestep_skip_range: Tuple[float, float] = (-1, 681)
-    temporal_attention_timestep_skip_range: Tuple[float, float] = (-1, 681)
+    spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
+    temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)

     # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
     low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
     high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)

     # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
-    alpha_low_frequency = 1.1
-    alpha_high_frequency = 1.1
+    alpha_low_frequency: float = 1.1
+    alpha_high_frequency: float = 1.1

     # n as described in CFG-Cache explanation in the paper - dependent on the model
     unconditional_batch_skip_range: int = 5
-    unconditional_batch_timestep_skip_range: Tuple[float, float] = (-1, 641)
+    unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)

     spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
     temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
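
The switch from Tuple[float, float] to Tuple[int, int] matches how these ranges are used: timesteps are integers, and a lower bound of -1 only makes sense with a strict comparison, so that timestep 0 still falls inside the window. A minimal sketch of such a membership check (the helper name is hypothetical, not part of this diff):

def _is_within(timestep: int, timestep_skip_range) -> bool:
    # Hypothetical open-interval check: a lower bound of -1 admits
    # timestep 0, and the upper bound itself is excluded.
    lower, upper = timestep_skip_range
    return lower < timestep < upper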
@@ -184,10 +184,10 @@ def __init__(
         self.high_frequency_weight_callback = high_frequency_weight_callback
         self.uncond_skip_callback = uncond_skip_callback

-        self.iteration = 0
-        self.low_frequency_delta = None
-        self.high_frequency_delta = None
-        self.is_guidance_distilled = None
+        self.iteration: int = 0
+        self.low_frequency_delta: torch.Tensor = None
+        self.high_frequency_delta: torch.Tensor = None
+        self.is_guidance_distilled: bool = None

     def reset(self):
         self.iteration = 0
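
For context on the low_frequency_delta/high_frequency_delta state annotated above: Equation 11 of the paper weights cached low- and high-frequency residuals by ⍺1 and ⍺2 when approximating the skipped unconditional branch of CFG. A rough sketch of that idea follows; the FFT box-mask split and the approximate_uncond helper are my illustration of the technique, not this file's actual implementation:

import torch

def split_frequencies(x: torch.Tensor):
    # Separate the last two dimensions into low/high-frequency parts
    # using a centered box mask in the shifted Fourier domain.
    freq = torch.fft.fftshift(torch.fft.fft2(x.float()), dim=(-2, -1))
    h, w = x.shape[-2:]
    mask = torch.zeros(h, w, device=x.device)
    mask[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 1.0
    low = torch.fft.ifft2(torch.fft.ifftshift(freq * mask, dim=(-2, -1))).real
    high = torch.fft.ifft2(torch.fft.ifftshift(freq * (1 - mask), dim=(-2, -1))).real
    return low, high

def approximate_uncond(cond_out, state, alpha_low=1.1, alpha_high=1.1):
    # Eq. 11, roughly: rebuild the skipped unconditional output from the
    # conditional output plus the cached, ⍺-weighted frequency deltas.
    low, high = split_frequencies(cond_out)
    return (low + alpha_low * state.low_frequency_delta) + (
        high + alpha_high * state.high_frequency_delta
    )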
@@ -213,10 +213,10 @@ def __init__(
         self.skip_callback = skip_callback
         self.weight_callback = weight_callback

-        self.iteration = 0
-        self.batch_size = None
-        self.cache = None
-        self.is_guidance_distilled = None
+        self.iteration: int = 0
+        self.batch_size: int = None
+        self.cache: Tuple[torch.Tensor, torch.Tensor] = None
+        self.is_guidance_distilled: bool = None

     def reset(self):
         self.iteration = 0
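
The new cache: Tuple[torch.Tensor, torch.Tensor] annotation suggests this state holds the two most recent attention outputs. A hedged sketch of how such a state could drive block skipping; the function name and the simple reuse rule are assumptions for illustration, not code from this diff:

def cached_attention_forward(block, hidden_states, state, skip_range):
    # Recompute the block only on every `skip_range`-th iteration;
    # otherwise reuse the newest cached output.
    if state.cache is not None and state.iteration % skip_range != 0:
        output = state.cache[-1]
    else:
        output = block(hidden_states)
        previous = state.cache[-1] if state.cache is not None else output
        state.cache = (previous, output)  # keep the two most recent outputs
    state.iteration += 1
    return output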
@@ -232,9 +232,6 @@ def apply_faster_cache(
     r"""
     Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.

-    Note: FasterCache should only be applied when using classifer-free guidance. It will not work as expected even if
-    the inference runs successfully.
-
     Args:
         pipeline (`DiffusionPipeline`):
             The diffusion pipeline to apply FasterCache to.
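
End to end, the entry point reads as config-in, hooks-applied. A usage sketch under the assumption that FasterCacheConfig and apply_faster_cache are importable alongside the pipeline; the pipeline choice and checkpoint are illustrative, and all config fields come from the hunks above:

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 641),
)
apply_faster_cache(pipe, config)

video = pipe("a panda playing a guitar", num_inference_steps=50).frames[0]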