
Commit a641182

update
1 parent 13d5af7

File tree: 2 files changed (+223, -10 lines)


src/diffusers/models/hooks.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -29,6 +29,7 @@ class ModelHook:
     def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -38,6 +39,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
@@ -54,6 +56,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass has been executed just before this event.
@@ -67,15 +70,17 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
+
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
         """
         return module

-    def reset_state(self, module: torch.nn.Module):
+    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
         if self._is_stateful:
             raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+        return module


 class SequentialHook(ModelHook):
@@ -108,16 +113,21 @@ def reset_state(self, module):
         for hook in self.hooks:
             if hook._is_stateful:
                 hook.reset_state(module)
+        return module


 def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:
     r"""
     Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook. To remove
     this behavior and restore the original `forward` method, use `remove_hook_from_module`.
+
     <Tip warning={true}>
+
     If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
     together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
+
     </Tip>
+
     Args:
         module (`torch.nn.Module`):
             The module to attach a hook to.
@@ -168,6 +178,7 @@ def new_forward(module, *args, **kwargs):
 def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
     """
     Removes any hook attached to a module via `add_hook_to_module`.
+
     Args:
         module (`torch.nn.Module`):
             The module to remove the hook from.
@@ -201,6 +212,7 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
 def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
     """
     Resets the state of all stateful hooks attached to a module.
+
     Args:
         module (`torch.nn.Module`):
             The module to reset the stateful hooks from.
```
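
To make the lifecycle concrete: `add_hook_to_module` rewrites the module's `forward` so that `pre_forward` runs before the original forward and `post_forward` runs after it, and with this commit `reset_state` returns the module so calls can be chained. Below is a minimal sketch, assuming this commit's module layout (`diffusers.models.hooks`); the `CountingHook` is hypothetical and exists only to exercise the API:

```python
import torch
import torch.nn as nn

# Assumes this commit's module layout; not yet a public diffusers API.
from diffusers.models.hooks import ModelHook, add_hook_to_module, remove_hook_from_module


class CountingHook(ModelHook):
    _is_stateful = False  # stateful hooks must also implement reset_state

    def __init__(self):
        self.calls = 0

    def pre_forward(self, module, *args, **kwargs):
        # Runs just before the module's original forward; must return args/kwargs.
        self.calls += 1
        return args, kwargs

    def post_forward(self, module, output):
        # Runs just after the original forward; may inspect or replace the output.
        return output


linear = nn.Linear(4, 4)
hook = CountingHook()
add_hook_to_module(linear, hook)  # rewrites linear.forward in place

linear(torch.randn(1, 4))
linear(torch.randn(1, 4))
assert hook.calls == 2

remove_hook_from_module(linear)  # restores the original forward
```

Passing `append=True` instead chains the new hook with any existing one through `SequentialHook`, whose `reset_state` likewise returns the module after this change.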

src/diffusers/pipelines/teacache_utils.py

Lines changed: 210 additions & 9 deletions
```diff
@@ -12,12 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional, Tuple
+from dataclasses import dataclass, field
+from typing import List, Optional

+import numpy as np
 import torch
 import torch.nn as nn

+from ..models import (
+    FluxTransformer2DModel,
+    HunyuanVideoTransformer3DModel,
+    LTXVideoTransformer3DModel,
+    LuminaNextDiT2DModel,
+    MochiTransformer3DModel,
+)
 from ..models.hooks import ModelHook, add_hook_to_module
 from ..utils import logging
 from .pipeline_utils import DiffusionPipeline
@@ -26,26 +35,218 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Source: https://github.com/ali-vilab/TeaCache
+# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use.
+# fmt: off
+_MODEL_TO_POLY_COEFFICIENTS = {
+    FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
+    HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02],
+    LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03],
+    LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344],
+    MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03],
+}
+# fmt: on
+
+_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = {
+    FluxTransformer2DModel: 0.25,
+    HunyuanVideoTransformer3DModel: 0.1,
+    LTXVideoTransformer3DModel: 0.05,
+    LuminaNextDiT2DModel: 0.2,
+    MochiTransformer3DModel: 0.06,
+}
+
+_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = {
+    FluxTransformer2DModel: "transformer_blocks.0.norm1",
+}
+
+_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = {
+    FluxTransformer2DModel: "norm_out",
+}
+
+_DEFAULT_SKIP_LAYER_IDENTIFIERS = [
+    "blocks",
+    "transformer_blocks",
+    "single_transformer_blocks",
+    "temporal_transformer_blocks",
+]
+

 @dataclass
 class TeaCacheConfig:
-    pass
+    l1_threshold: Optional[float] = None
+
+    timestep_modulated_layer_identifier: Optional[str] = None
+
+    skip_layer_identifiers: List[str] = field(default_factory=lambda: list(_DEFAULT_SKIP_LAYER_IDENTIFIERS))
+
+    _polynomial_coefficients: Optional[List[float]] = None


 class TeaCacheDenoiserState:
     def __init__(self):
-        self.iteration = 0
-        self.accumulated_l1_difference = 0.0
-        self.timestep_modulated_cache = None
-
+        self.iteration: int = 0
+        self.accumulated_l1_difference: float = 0.0
+        self.timestep_modulated_cache: Optional[torch.Tensor] = None
+        self.residual_cache: Optional[torch.Tensor] = None
+        self.should_skip_blocks: bool = False
+
     def reset(self):
         self.iteration = 0
         self.accumulated_l1_difference = 0.0
         self.timestep_modulated_cache = None
+        self.residual_cache = None


-def apply_teacache(pipeline: DiffusionPipeline, config: TeaCacheConfig, denoiser: Optional[nn.Module]) -> None:
-    r"""Applies [TeaCache]() to a given pipeline or denoiser module.
-
+def apply_teacache(
+    pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None
+) -> None:
+    r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module.
+
     Args:
         TODO
     """
+    if config is None:
+        logger.warning("No TeaCacheConfig provided. Using default configuration.")
+        config = TeaCacheConfig()
+
+    if denoiser is None:
+        denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
+
+    if isinstance(denoiser, tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())):
+        if config.l1_threshold is None:
+            logger.info(
+                f"No L1 threshold was provided for {type(denoiser)}. Using the default threshold from the TeaCache paper for a 1.5x speedup. "
+                f"For a higher speedup, increase the threshold."
+            )
+            config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
+        if config.timestep_modulated_layer_identifier is None:
+            logger.info(
+                f"No timestep-modulated layer identifier was provided for {type(denoiser)}. Using the default identifier from the TeaCache paper."
+            )
+            config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
+        if config._polynomial_coefficients is None:
+            logger.info(
+                f"No polynomial coefficients were provided for {type(denoiser)}. Using the default coefficients from the TeaCache paper."
+            )
+            config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
+    else:
+        if config.l1_threshold is None:
+            raise ValueError(
+                f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
+                f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
+            )
+        if config.timestep_modulated_layer_identifier is None:
+            raise ValueError(
+                f"No timestep-modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
+                f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
+            )
+        if config._polynomial_coefficients is None:
+            raise ValueError(
+                f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not "
+                f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
+                f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
+            )
+
+    timestep_modulated_layer_matches = list(
+        {
+            module
+            for name, module in denoiser.named_modules()
+            if re.match(config.timestep_modulated_layer_identifier, name)
+        }
+    )
+
+    if len(timestep_modulated_layer_matches) == 0:
+        raise ValueError(
+            f"No layer in the denoiser module matched the provided timestep-modulated layer identifier: "
+            f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier."
+        )
+    if len(timestep_modulated_layer_matches) > 1:
+        logger.warning(
+            f"Multiple layers in the denoiser module matched the provided timestep-modulated layer identifier: "
+            f"{config.timestep_modulated_layer_identifier}. Using the first match."
+        )
+
+    denoiser_state = TeaCacheDenoiserState()
+
+    timestep_modulated_layer = timestep_modulated_layer_matches[0]
+    hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients)
+    add_hook_to_module(timestep_modulated_layer, hook, append=True)
+
+    skip_layer_identifiers = config.skip_layer_identifiers
+    skip_layer_matches = list(
+        {
+            module
+            for name, module in denoiser.named_modules()
+            if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
+        }
+    )
+
+    for skip_layer in skip_layer_matches:
+        hook = DenoiserStateBasedSkipLayerHook(denoiser_state)
+        add_hook_to_module(skip_layer, hook, append=True)
+
+
+class TimestepModulatedOutputCacheHook(ModelHook):
+    # The denoiser hook will reset its state, so we don't have to handle it here.
+    _is_stateful = False
+
+    def __init__(
+        self,
+        denoiser_state: TeaCacheDenoiserState,
+        l1_threshold: float,
+        polynomial_coefficients: List[float],
+    ) -> None:
+        self.denoiser_state = denoiser_state
+        self.l1_threshold = l1_threshold
+        # TODO(aryan): implement torch equivalent
+        self.rescale_fn = np.poly1d(polynomial_coefficients)
+
+    def post_forward(self, module, output):
+        if isinstance(output, tuple):
+            # This assumes that the first element of the output tuple is the timestep-modulated noise output.
+            # For Diffusers models, this is true. For models outside Diffusers, users will have to ensure
+            # that the first element of the output tuple is the timestep-modulated noise output (this seems
+            # to be the case for most research model implementations).
+            timestep_modulated_noise = output[0]
+        elif torch.is_tensor(output):
+            timestep_modulated_noise = output
+        else:
+            raise ValueError(
+                f"Expected output to be a tensor or a tuple whose first element is the timestep-modulated noise. "
+                f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep-"
+                f"modulated noise output as the first element."
+            )
+
+        if self.denoiser_state.timestep_modulated_cache is not None:
+            l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean()
+            normalized_l1_diff = l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean()
+            rescaled_l1_diff = self.rescale_fn(normalized_l1_diff)
+            self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff
+
+            if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold:
+                self.denoiser_state.should_skip_blocks = True
+                self.denoiser_state.accumulated_l1_difference = 0.0
+            else:
+                self.denoiser_state.should_skip_blocks = False
+
+        self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise
+        return output
+
+
+class DenoiserStateBasedSkipLayerHook(ModelHook):
+    _is_stateful = False
+
+    def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None:
+        self.denoiser_state = denoiser_state
+
+    def new_forward(self, module, *args, **kwargs):
+        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
+
+        if not self.denoiser_state.should_skip_blocks:
+            output = module._old_forward(*args, **kwargs)
+        else:
+            # Diffusers models either expect one output (hidden_states) or a tuple of two outputs
+            # (hidden_states, encoder_hidden_states). Returning a tuple of None values handles both
+            # cases. It is okay to do because we are not going to be using these anywhere if
+            # self.denoiser_state.should_skip_blocks is True.
+            output = (None, None)

+        return module._diffusers_hook.post_forward(module, output)
```
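
The core of `TimestepModulatedOutputCacheHook` is a running, polynomial-rescaled relative L1 distance between successive timestep-modulated outputs. Here is a standalone sketch of that arithmetic, using the Flux coefficients and the 0.25 default threshold from the tables above; the random tensors are stand-ins for real activations, and the comparison direction mirrors this commit (note that the upstream TeaCache reference skips while the accumulated distance stays *below* the threshold and recomputes once it crosses it):

```python
import numpy as np
import torch

# Flux values copied from the tables above.
coefficients = [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
l1_threshold = 0.25

rescale_fn = np.poly1d(coefficients)  # degree-4 polynomial in the relative L1 distance

cache = None
accumulated = 0.0
for step in range(8):
    # Stand-in for the output of transformer_blocks.0.norm1 at this step.
    modulated = torch.randn(2, 16, 64)
    should_skip = False
    if cache is not None:
        # Relative L1 change against the cached output, rescaled into the regime
        # the calibrated threshold was fit against.
        rel_l1 = ((modulated - cache).abs().mean() / cache.abs().mean()).item()
        accumulated += float(rescale_fn(rel_l1))
        if accumulated >= l1_threshold:  # comparison direction as written in this commit
            should_skip = True
            accumulated = 0.0
    cache = modulated
    print(f"step {step}: skip transformer blocks -> {should_skip}")
```

Assuming the commit's entry point, enabling TeaCache on a pipeline would then look roughly like this (the import path follows this commit's file location and may change before merge):

```python
import torch

from diffusers import FluxPipeline
from diffusers.pipelines.teacache_utils import TeaCacheConfig, apply_teacache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
apply_teacache(pipe, TeaCacheConfig(l1_threshold=0.25))  # Flux defaults are inferred when omitted
```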
