@@ -30,34 +30,35 @@ class TaylorSeerCacheConfig:
3030
3131 Attributes:
3232 cache_interval (`int`, defaults to `5`):
33- The interval between full computation steps. After a full computation, the cached (predicted) outputs are reused
34- for this many subsequent denoising steps before refreshing with a new full forward pass.
33+ The interval between full computation steps. After a full computation, the cached (predicted) outputs are
34+ reused for this many subsequent denoising steps before refreshing with a new full forward pass.
3535
3636 disable_cache_before_step (`int`, defaults to `3`):
37- The denoising step index before which caching is disabled, meaning full computation is performed for the initial
38- steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps,
39- Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step.
37+ The denoising step index before which caching is disabled, meaning full computation is performed for the
38+ initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
39+ these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
40+ step.
4041
4142 disable_cache_after_step (`int`, *optional*, defaults to `None`):
42- The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full
43- computations without predictions or state updates, ensuring accuracy in later stages if needed.
43+ The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
44+ full computations without predictions or state updates, ensuring accuracy in later stages if needed.
4445
4546 max_order (`int`, defaults to `1`):
46- The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better
47- approximations but increase computation and memory usage.
47+ The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
48+ better approximations but increase computation and memory usage.
4849
4950 taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
50- Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect
51- stability; higher precision improves accuracy at the cost of more memory.
51+ Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
52+ affect stability; higher precision improves accuracy at the cost of more memory.
5253
5354 inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
54- Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module
55- computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during
56- prediction steps to skip computation cheaply.
55+ Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the
56+ module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape)
57+ during prediction steps to skip computation cheaply.
5758
5859 active_identifiers (`List[str]`, *optional*, defaults to `None`):
59- Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs
60- are approximated and cached for reuse.
60+ Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
61+ outputs are approximated and cached for reuse.
6162
6263 use_lite_mode (`bool`, *optional*, defaults to `False`):
6364 Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
@@ -118,7 +119,6 @@ def __init__(
118119 self .device : Optional [torch .device ] = None
119120 self .current_step : int = - 1
120121
121-
122122 def reset (self ) -> None :
123123 self .current_step = - 1
124124 self .last_update_step = None
@@ -223,13 +223,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
223223 state .current_step += 1
224224 current_step = state .current_step
225225 is_warmup_phase = current_step < self .disable_cache_before_step
226- is_compute_interval = (( current_step - self .disable_cache_before_step - 1 ) % self .cache_interval == 0 )
226+ is_compute_interval = (current_step - self .disable_cache_before_step - 1 ) % self .cache_interval == 0
227227 is_cooldown_phase = self .disable_cache_after_step is not None and current_step >= self .disable_cache_after_step
228- should_compute = (
229- is_warmup_phase
230- or is_compute_interval
231- or is_cooldown_phase
232- )
228+ should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
233229 if should_compute :
234230 outputs = self .fn_ref .original_forward (* args , ** kwargs )
235231 wrapped_outputs = (outputs ,) if isinstance (outputs , torch .Tensor ) else outputs
@@ -255,8 +251,8 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
255251 """
256252 Applies the TaylorSeer cache to a given model (typically the transformer / UNet of a pipeline).
257253
258- This function hooks selected modules in the model to enable caching or skipping based on the provided configuration,
259- reducing redundant computations in diffusion denoising loops.
254+ This function hooks selected modules in the model to enable caching or skipping based on the provided
255+ configuration, reducing redundant computations in diffusion denoising loops.
260256
261257 Args:
262258 module (torch.nn.Module): The model subtree to apply the hooks to.
@@ -338,4 +334,4 @@ def _apply_taylorseer_cache_hook(
338334 state_manager = state_manager ,
339335 )
340336
341- registry .register_hook (hook , _TAYLORSEER_CACHE_HOOK )
337+ registry .register_hook (hook , _TAYLORSEER_CACHE_HOOK )
0 commit comments