@@ -1,13 +1,13 @@
 import math
 import re
 from dataclasses import dataclass
-from typing import Optional, List, Dict, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn

-from .hooks import ModelHook, StateManager, HookRegistry
 from ..utils import logging
+from .hooks import HookRegistry, ModelHook, StateManager


 logger = logging.get_logger(__name__)
@@ -19,60 +19,51 @@
 )
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
 _TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-_BLOCK_IDENTIFIERS = (
-    "^[^.]*block[^.]*\\.[^.]+$",
-)
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
 _PROJ_OUT_IDENTIFIERS = ("^proj_out$",)

+
 @dataclass
 class TaylorSeerCacheConfig:
     """
-    Configuration for TaylorSeer cache.
-    See: https://huggingface.co/papers/2503.06923
+    Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923

     Attributes:
         warmup_steps (`int`, defaults to `3`):
-            Number of denoising steps to run with full computation
-            before enabling caching. During warmup, the Taylor series factors
-            are still updated, but no predictions are used.
+            Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
+            series factors are still updated, but no predictions are used.

         predict_steps (`int`, defaults to `5`):
-            Number of prediction (cached) steps to take between two full
-            computations. That is, once a module state is refreshed, it will
-            be reused for `predict_steps` subsequent denoising steps, then a new
-            full forward will be computed on the next step.
+            Number of prediction (cached) steps to take between two full computations. That is, once a module state is
+            refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
+            be computed on the next step.

         stop_predicts (`int`, *optional*, defaults to `None`):
-            Denoising step index at which caching is disabled.
-            If provided, for `self.current_step >= stop_predicts` all modules are
-            evaluated normally (no predictions, no state updates).
+            Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
+            all modules are evaluated normally (no predictions, no state updates).

         max_order (`int`, defaults to `1`):
-            Maximum order of Taylor series expansion to approximate the
-            features. Higher order gives closer approximation but more compute.
+            Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
+            approximation but more compute.

         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type for computing Taylor series expansion factors.
-            Use lower precision to reduce memory usage.
-            Use higher precision to improve numerical stability.
+            Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
+            higher precision to improve numerical stability.

         skip_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            "skip" mode, where the module is evaluated during warmup /
-            refresh, but then replaced by a cheap dummy tensor during
-            prediction steps.
+            Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
+            during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.

         cache_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            Taylor-series caching mode.
+            Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.

         lite (`bool`, *optional*, defaults to `False`):
-            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
-            any user-provided `skip_identifiers` or `cache_identifiers` patterns.
+            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
+            `skip_identifiers` or `cache_identifiers` patterns.
     Notes:
         - Patterns are applied with `re.fullmatch` on `module_name`.
-        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
-          one of those patterns will be hooked.
+        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
+          patterns will be hooked.
         - If neither is provided, all attention-like modules will be hooked.
     """

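For reference, the identifier constants above are matched with `re.fullmatch` against dotted module names. A minimal sketch of that selection logic, where the patterns are copied from this file but the module names are hypothetical examples, not taken from a real model:

```python
import re

# Patterns copied from the constants above; module names below are made up.
_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)

candidates = [
    "transformer_blocks.0",       # one dot, "block" before it -> matches _BLOCK_IDENTIFIERS
    "transformer_blocks.0.attn",  # two dots -> no match for _BLOCK_IDENTIFIERS
    "proj_out",                   # exact name -> matches _PROJ_OUT_IDENTIFIERS
]

for name in candidates:
    hooked = any(re.fullmatch(pattern, name) for pattern in _BLOCK_IDENTIFIERS + _PROJ_OUT_IDENTIFIERS)
    print(f"{name!r} -> hooked={hooked}")
```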
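The cadence described by `warmup_steps`, `predict_steps`, and `stop_predicts` can also be illustrated with a small schedule function. This is one plausible reading of the documented behavior, not the hook implementation; the helper name `runs_full_forward` is invented for the sketch:

```python
from typing import Optional

def runs_full_forward(step: int, warmup_steps: int = 3, predict_steps: int = 5,
                      stop_predicts: Optional[int] = None) -> bool:
    """Hypothetical helper: True when a denoising step should recompute instead of predict."""
    if step < warmup_steps:
        return True  # warmup: full computation (factors still updated, no predictions)
    if stop_predicts is not None and step >= stop_predicts:
        return True  # caching disabled from this step onward
    # After each refresh the state is reused for `predict_steps` steps,
    # then a full forward runs on the following step.
    return (step - warmup_steps) % (predict_steps + 1) == 0

print(["full" if runs_full_forward(s) else "cached" for s in range(12)])
# ['full', 'full', 'full', 'full', 'cached', 'cached', 'cached',
#  'cached', 'cached', 'full', 'cached', 'cached']
```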
@@ -255,13 +246,13 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig
     ```python
     >>> import torch
     >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
-    >>>
+
     >>> pipe = FluxPipeline.from_pretrained(
     ...     "black-forest-labs/FLUX.1-dev",
     ...     torch_dtype=torch.bfloat16,
     ... )
     >>> pipe.to("cuda")
-    >>>
+
     >>> config = TaylorSeerCacheConfig(
     ...     predict_steps=5,
     ...     max_order=1,
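Between full computations, the cache extrapolates each hooked module's output with a truncated Taylor expansion in the denoising-step index. A minimal numeric sketch of the first-order case (`max_order=1`), using made-up tensors in place of real module outputs and assuming unit spacing between refreshes; this is not the library's internal code:

```python
import torch

taylor_factors_dtype = torch.bfloat16  # mirrors the config default

# Two consecutive full forwards (made-up tensors standing in for module outputs).
feat_prev = torch.randn(2, 4, dtype=taylor_factors_dtype)
feat_curr = feat_prev + 0.1 * torch.randn(2, 4, dtype=taylor_factors_dtype)

# max_order=1: keep the value and a finite-difference estimate of the derivative.
factors = [feat_curr, feat_curr - feat_prev]

# During prediction steps, extrapolate instead of recomputing:
# F(t + k) ~= F(t) + k * F'(t)
k = 2  # number of steps since the last refresh
prediction = factors[0] + k * factors[1]
```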