Commit 309ce72

Author: toilaluan
Commit message: quality & style
1 parent 83b6253, commit 309ce72

File tree: 1 file changed (+22, -26 lines)


src/diffusers/hooks/taylorseer_cache.py (22 additions, 26 deletions)
@@ -30,34 +30,35 @@ class TaylorSeerCacheConfig:
 
     Attributes:
         cache_interval (`int`, defaults to `5`):
-            The interval between full computation steps. After a full computation, the cached (predicted) outputs are reused
-            for this many subsequent denoising steps before refreshing with a new full forward pass.
+            The interval between full computation steps. After a full computation, the cached (predicted) outputs are
+            reused for this many subsequent denoising steps before refreshing with a new full forward pass.
 
         disable_cache_before_step (`int`, defaults to `3`):
-            The denoising step index before which caching is disabled, meaning full computation is performed for the initial
-            steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps,
-            Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step.
+            The denoising step index before which caching is disabled, meaning full computation is performed for the
+            initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
+            these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
+            step.
 
         disable_cache_after_step (`int`, *optional*, defaults to `None`):
-            The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full
-            computations without predictions or state updates, ensuring accuracy in later stages if needed.
+            The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
+            full computations without predictions or state updates, ensuring accuracy in later stages if needed.
 
         max_order (`int`, defaults to `1`):
-            The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better
-            approximations but increase computation and memory usage.
+            The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
+            better approximations but increase computation and memory usage.
 
         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect
-            stability; higher precision improves accuracy at the cost of more memory.
+            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
+            affect stability; higher precision improves accuracy at the cost of more memory.
 
         inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module
-            computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during
-            prediction steps to skip computation cheaply.
+            Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the
+            module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape)
+            during prediction steps to skip computation cheaply.
 
         active_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs
-            are approximated and cached for reuse.
+            Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
+            outputs are approximated and cached for reuse.
 
         use_lite_mode (`bool`, *optional*, defaults to `False`):
             Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
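The `max_order` / `taylor_factors_dtype` attributes above refer to TaylorSeer's core idea: between full forward passes, each module's output is extrapolated from finite-difference "Taylor factors" collected at earlier steps. A minimal scalar sketch of that idea, with hypothetical helper names (`update_taylor_factors`, `predict_output`) that are not part of the diffusers API; the real hook stores per-module tensors in `taylor_factors_dtype`:

```python
import math

def update_taylor_factors(factors, new_value, dt, max_order=1):
    """Refresh the divided-difference factors with a freshly computed output.
    factors[k] approximates the k-th derivative w.r.t. the step index."""
    updated = [new_value]
    for k in range(max_order):
        if k >= len(factors):
            break
        # First-order divided difference of the k-th factor.
        updated.append((updated[k] - factors[k]) / dt)
    return updated

def predict_output(factors, dt):
    """Taylor expansion around the last full compute: sum_k f_k * dt^k / k!."""
    return sum(f * dt**k / math.factorial(k) for k, f in enumerate(factors))

# Toy module whose output is linear in the step index, y(t) = 3t + 1.
# Full computes at t=0 and t=5, then a prediction at t=7.
factors = update_taylor_factors([], 1.0, dt=1)        # y(0) = 1
factors = update_taylor_factors(factors, 16.0, dt=5)  # y(5) = 16, slope 3
print(predict_output(factors, dt=2))                  # 22.0, exact for a linear y
```

A first-order expansion is exact here because the toy output is linear in the step index; for real denoising trajectories the prediction is approximate, which is why `cache_interval` periodically refreshes the factors with a full forward pass.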
@@ -118,7 +119,6 @@ def __init__(
         self.device: Optional[torch.device] = None
         self.current_step: int = -1
 
-
     def reset(self) -> None:
         self.current_step = -1
         self.last_update_step = None
@@ -223,13 +223,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
         state.current_step += 1
         current_step = state.current_step
         is_warmup_phase = current_step < self.disable_cache_before_step
-        is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0)
+        is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
         is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
-        should_compute = (
-            is_warmup_phase
-            or is_compute_interval
-            or is_cooldown_phase
-        )
+        should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
         if should_compute:
             outputs = self.fn_ref.original_forward(*args, **kwargs)
             wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
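The gating logic in this hunk can be exercised in isolation. A standalone sketch (the free function `should_compute` is a hypothetical stand-in for the hook method, with the config values passed as defaults):

```python
def should_compute(step, cache_interval=5, disable_cache_before_step=3,
                   disable_cache_after_step=None):
    """Mirror of the hook's gating: full compute during warmup, on every
    cache_interval-th step afterwards, and during the optional cooldown."""
    is_warmup_phase = step < disable_cache_before_step
    is_compute_interval = (step - disable_cache_before_step - 1) % cache_interval == 0
    is_cooldown_phase = (disable_cache_after_step is not None
                         and step >= disable_cache_after_step)
    return is_warmup_phase or is_compute_interval or is_cooldown_phase

# With the defaults, steps 0-2 are warmup, step 4 is the first refresh, and
# refreshes then recur every 5 steps; all other steps reuse the prediction.
print([s for s in range(12) if should_compute(s)])  # [0, 1, 2, 4, 9]
```

Note the first cached (predicted) step is `disable_cache_before_step` itself, matching the docstring's "Caching begins at this step"; setting `disable_cache_after_step=10` would additionally force full computes from step 10 onward.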
@@ -255,8 +251,8 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     """
     Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet).
 
-    This function hooks selected modules in the model to enable caching or skipping based on the provided configuration,
-    reducing redundant computations in diffusion denoising loops.
+    This function hooks selected modules in the model to enable caching or skipping based on the provided
+    configuration, reducing redundant computations in diffusion denoising loops.
 
     Args:
         module (torch.nn.Module): The model subtree to apply the hooks to.
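The hook pattern this docstring describes (swap a module's forward for one that only sometimes runs the original) can be sketched without torch. The names `ToyCacheHook` / `ToyModule` are hypothetical; the real implementation registers through a hook registry and serves Taylor predictions rather than the raw cached value:

```python
class ToyCacheHook:
    """Torch-free sketch: replace forward with a new_forward that runs the
    original computation only on selected steps and otherwise reuses a cache."""

    def __init__(self, module, compute_steps):
        self.original_forward = module.forward
        self.compute_steps = set(compute_steps)
        self.current_step = -1
        self.cache = None
        module.forward = self.new_forward  # "register" the hook

    def new_forward(self, *args, **kwargs):
        self.current_step += 1
        if self.cache is None or self.current_step in self.compute_steps:
            self.cache = self.original_forward(*args, **kwargs)
        return self.cache

class ToyModule:
    def __init__(self):
        self.full_computes = 0

    def forward(self, x):
        self.full_computes += 1
        return 2 * x

module = ToyModule()
ToyCacheHook(module, compute_steps={0, 3})
outputs = [module.forward(t) for t in range(6)]
print(outputs, module.full_computes)  # [0, 0, 0, 6, 6, 6] 2
```

Six forward calls trigger only two real computations; the intermediate steps return stale values, which is exactly the error the Taylor-series prediction in the real hook is there to reduce.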
@@ -338,4 +334,4 @@ def _apply_taylorseer_cache_hook(
         state_manager=state_manager,
     )
 
-    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
+    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
