Skip to content

Commit f3cb80c

Browse files
committed
update
1 parent d68977d commit f3cb80c

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

src/diffusers/pipelines/faster_cache_utils.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,20 +141,20 @@ class FasterCacheConfig:
141141
temporal_attention_block_skip_range: Optional[int] = None
142142

143143
# TODO(aryan): write heuristics for the best way to obtain these values
144-
spatial_attention_timestep_skip_range: Tuple[float, float] = (-1, 681)
145-
temporal_attention_timestep_skip_range: Tuple[float, float] = (-1, 681)
144+
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
145+
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
146146

147147
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
148148
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
149149
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
150150

151151
# ⍺1 and ⍺2 as mentioned in Equation 11 of the paper
152-
alpha_low_frequency = 1.1
153-
alpha_high_frequency = 1.1
152+
alpha_low_frequency: float = 1.1
153+
alpha_high_frequency: float = 1.1
154154

155155
# n as described in CFG-Cache explanation in the paper - dependent on the model
156156
unconditional_batch_skip_range: int = 5
157-
unconditional_batch_timestep_skip_range: Tuple[float, float] = (-1, 641)
157+
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
158158

159159
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
160160
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
@@ -184,10 +184,10 @@ def __init__(
184184
self.high_frequency_weight_callback = high_frequency_weight_callback
185185
self.uncond_skip_callback = uncond_skip_callback
186186

187-
self.iteration = 0
188-
self.low_frequency_delta = None
189-
self.high_frequency_delta = None
190-
self.is_guidance_distilled = None
187+
self.iteration: int = 0
188+
self.low_frequency_delta: torch.Tensor = None
189+
self.high_frequency_delta: torch.Tensor = None
190+
self.is_guidance_distilled: bool = None
191191

192192
def reset(self):
193193
self.iteration = 0
@@ -213,10 +213,10 @@ def __init__(
213213
self.skip_callback = skip_callback
214214
self.weight_callback = weight_callback
215215

216-
self.iteration = 0
217-
self.batch_size = None
218-
self.cache = None
219-
self.is_guidance_distilled = None
216+
self.iteration: int = 0
217+
self.batch_size: int = None
218+
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
219+
self.is_guidance_distilled: bool = None
220220

221221
def reset(self):
222222
self.iteration = 0
@@ -232,9 +232,6 @@ def apply_faster_cache(
232232
r"""
233233
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
234234
235-
Note: FasterCache should only be applied when using classifier-free guidance. Otherwise, it will not work as expected even if
236-
the inference runs successfully.
237-
238235
Args:
239236
pipeline (`DiffusionPipeline`):
240237
The diffusion pipeline to apply FasterCache to.

0 commit comments

Comments
 (0)