|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import re |
15 | 16 | from dataclasses import dataclass |
16 | | -from typing import Any, Callable, Dict, Optional, Tuple |
| 17 | +from typing import List, Optional |
17 | 18 |
|
| 19 | +import numpy as np |
18 | 20 | import torch |
19 | 21 | import torch.nn as nn |
20 | 22 |
|
| 23 | +from ..models import ( |
| 24 | + FluxTransformer2DModel, |
| 25 | + HunyuanVideoTransformer3DModel, |
| 26 | + LTXVideoTransformer3DModel, |
| 27 | + LuminaNextDiT2DModel, |
| 28 | + MochiTransformer3DModel, |
| 29 | +) |
21 | 30 | from ..models.hooks import ModelHook, add_hook_to_module |
22 | 31 | from ..utils import logging |
23 | 32 | from .pipeline_utils import DiffusionPipeline |
|
26 | 35 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
27 | 36 |
|
28 | 37 |
|
| 38 | +# Source: https://github.com/ali-vilab/TeaCache |
| 39 | +# TODO(aryan): Implement methods to calibrate and compute polynomial coefficients on-the-fly, and export to file for re-use. |
| 40 | +# fmt: off |
| 41 | +_MODEL_TO_POLY_COEFFICIENTS = { |
| 42 | + FluxTransformer2DModel: [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01], |
| 43 | + HunyuanVideoTransformer3DModel: [7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02], |
| 44 | + LTXVideoTransformer3DModel: [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03], |
| 45 | + LuminaNextDiT2DModel: [393.76566581, -603.50993606, 209.10239044, -23.00726601, 0.86377344], |
| 46 | + MochiTransformer3DModel: [-3.51241319e03, 8.11675948e02, -6.09400215e01, 2.42429681e00, 3.05291719e-03], |
| 47 | +} |
| 48 | +# fmt: on |
| 49 | + |
| 50 | +_MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD = { |
| 51 | + FluxTransformer2DModel: 0.25, |
| 52 | + HunyuanVideoTransformer3DModel: 0.1, |
| 53 | + LTXVideoTransformer3DModel: 0.05, |
| 54 | + LuminaNextDiT2DModel: 0.2, |
| 55 | + MochiTransformer3DModel: 0.06, |
| 56 | +} |
| 57 | + |
| 58 | +_MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER = { |
| 59 | + FluxTransformer2DModel: "transformer_blocks.0.norm1", |
| 60 | +} |
| 61 | + |
| 62 | +_MODEL_TO_SKIP_END_LAYER_IDENTIFIER = { |
| 63 | + FluxTransformer2DModel: "norm_out", |
| 64 | +} |
| 65 | + |
| 66 | +_DEFAULT_SKIP_LAYER_IDENTIFIERS = [ |
| 67 | + "blocks", |
| 68 | + "transformer_blocks", |
| 69 | + "single_transformer_blocks", |
| 70 | + "temporal_transformer_blocks", |
| 71 | +] |
| 72 | + |
| 73 | + |
29 | 74 | @dataclass |
30 | 75 | class TeaCacheConfig: |
31 | | - pass |
| 76 | + l1_threshold: Optional[float] = None |
| 77 | + |
| 78 | + skip_layer_identifiers: List[str] = _DEFAULT_SKIP_LAYER_IDENTIFIERS |
| 79 | + |
| 80 | + _polynomial_coefficients: Optional[List[float]] = None |
32 | 81 |
|
33 | 82 |
|
34 | 83 | class TeaCacheDenoiserState: |
35 | 84 | def __init__(self): |
36 | | - self.iteration = 0 |
37 | | - self.accumulated_l1_difference = 0.0 |
38 | | - self.timestep_modulated_cache = None |
39 | | - |
| 85 | + self.iteration: int = 0 |
| 86 | + self.accumulated_l1_difference: float = 0.0 |
| 87 | + self.timestep_modulated_cache: torch.Tensor = None |
| 88 | + self.residual_cache: torch.Tensor = None |
| 89 | + self.should_skip_blocks: bool = False |
| 90 | + |
40 | 91 | def reset(self): |
41 | 92 | self.iteration = 0 |
42 | 93 | self.accumulated_l1_difference = 0.0 |
43 | 94 | self.timestep_modulated_cache = None |
| 95 | + self.residual_cache = None |
44 | 96 |
|
45 | 97 |
|
46 | | -def apply_teacache(pipeline: DiffusionPipeline, config: TeaCacheConfig, denoiser: Optional[nn.Module]) -> None: |
47 | | - r"""Applies [TeaCache]() to a given pipeline or denoiser module. |
48 | | - |
| 98 | +def apply_teacache( |
| 99 | + pipeline: DiffusionPipeline, config: Optional[TeaCacheConfig] = None, denoiser: Optional[nn.Module] = None |
| 100 | +) -> None: |
| 101 | + r"""Applies [TeaCache](https://huggingface.co/papers/2411.19108) to a given pipeline or denoiser module. |
| 102 | +
|
49 | 103 | Args: |
50 | 104 | TODO |
51 | 105 | """ |
| 106 | + |
| 107 | + if config is None: |
| 108 | + logger.warning("No TeaCacheConfig provided. Using default configuration.") |
| 109 | + config = TeaCacheConfig() |
| 110 | + |
| 111 | + if denoiser is None: |
| 112 | + denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet |
| 113 | + |
| 114 | + if isinstance(denoiser, (_MODEL_TO_POLY_COEFFICIENTS.keys())): |
| 115 | + if config.l1_threshold is None: |
| 116 | + logger.info( |
| 117 | + f"No L1 threshold was provided for {type(denoiser)}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. " |
| 118 | + f"For higher speedup, increase the threshold." |
| 119 | + ) |
| 120 | + config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)] |
| 121 | + if config.timestep_modulated_layer_identifier is None: |
| 122 | + logger.info( |
| 123 | + f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using default identifier as provided in the TeaCache paper." |
| 124 | + ) |
| 125 | + config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)] |
| 126 | + if config._polynomial_coefficients is None: |
| 127 | + logger.info( |
| 128 | + f"No polynomial coefficients were provided for {type(denoiser)}. Using default coefficients as provided in the TeaCache paper." |
| 129 | + ) |
| 130 | + config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)] |
| 131 | + else: |
| 132 | + if config.l1_threshold is None: |
| 133 | + raise ValueError( |
| 134 | + f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported " |
| 135 | + f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute." |
| 136 | + ) |
| 137 | + if config.timestep_modulated_layer_identifier is None: |
| 138 | + raise ValueError( |
| 139 | + f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported " |
| 140 | + f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute." |
| 141 | + ) |
| 142 | + if config._polynomial_coefficients is None: |
| 143 | + raise ValueError( |
| 144 | + f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not " |
| 145 | + f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the " |
| 146 | + f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future." |
| 147 | + ) |
| 148 | + |
| 149 | + timestep_modulated_layer_matches = list( |
| 150 | + { |
| 151 | + module |
| 152 | + for name, module in denoiser.named_modules() |
| 153 | + if re.match(config.timestep_modulated_layer_identifier, name) |
| 154 | + } |
| 155 | + ) |
| 156 | + |
| 157 | + if len(timestep_modulated_layer_matches) == 0: |
| 158 | + raise ValueError( |
| 159 | + f"No layer in the denoiser module matched the provided timestep modulated layer identifier: " |
| 160 | + f"{config.timestep_modulated_layer_identifier}. Please provide a valid layer identifier." |
| 161 | + ) |
| 162 | + if len(timestep_modulated_layer_matches) > 1: |
| 163 | + logger.warning( |
| 164 | + f"Multiple layers in the denoiser module matched the provided timestep modulated layer identifier: " |
| 165 | + f"{config.timestep_modulated_layer_identifier}. Using the first match." |
| 166 | + ) |
| 167 | + |
| 168 | + denoiser_state = TeaCacheDenoiserState() |
| 169 | + |
| 170 | + timestep_modulated_layer = timestep_modulated_layer_matches[0] |
| 171 | + hook = TimestepModulatedOutputCacheHook(denoiser_state, config.l1_threshold, config._polynomial_coefficients) |
| 172 | + add_hook_to_module(timestep_modulated_layer, hook, append=True) |
| 173 | + |
| 174 | + skip_layer_identifiers = config.skip_layer_identifiers |
| 175 | + skip_layer_matches = list( |
| 176 | + { |
| 177 | + module |
| 178 | + for name, module in denoiser.named_modules() |
| 179 | + if any(re.match(identifier, name) for identifier in skip_layer_identifiers) |
| 180 | + } |
| 181 | + ) |
| 182 | + |
| 183 | + for skip_layer in skip_layer_matches: |
| 184 | + hook = DenoiserStateBasedSkipLayerHook(denoiser_state) |
| 185 | + add_hook_to_module(skip_layer, hook, append=True) |
| 186 | + |
| 187 | + |
| 188 | +class TimestepModulatedOutputCacheHook(ModelHook): |
| 189 | + # The denoiser hook will reset its state, so we don't have to handle it here |
| 190 | + _is_stateful = False |
| 191 | + |
| 192 | + def __init__( |
| 193 | + self, |
| 194 | + denoiser_state: TeaCacheDenoiserState, |
| 195 | + l1_threshold: float, |
| 196 | + polynomial_coefficients: List[float], |
| 197 | + ) -> None: |
| 198 | + self.denoiser_state = denoiser_state |
| 199 | + self.l1_threshold = l1_threshold |
| 200 | + # TODO(aryan): implement torch equivalent |
| 201 | + self.rescale_fn = np.poly1d(polynomial_coefficients) |
| 202 | + |
| 203 | + def post_forward(self, module, output): |
| 204 | + if isinstance(output, tuple): |
| 205 | + # This assumes that the first element of the output tuple is the timestep modulated noise output. |
| 206 | + # For Diffusers models, this is true. For models outside diffusers, users will have to ensure |
| 207 | + # that the first element of the output tuple is the timestep modulated noise output (seems to be |
| 208 | + # the case for most research model implementations). |
| 209 | + timestep_modulated_noise = output[0] |
| 210 | + elif torch.is_tensor(output): |
| 211 | + timestep_modulated_noise = output |
| 212 | + else: |
| 213 | + raise ValueError( |
| 214 | + f"Expected output to be a tensor or a tuple with first element as timestep modulated noise. " |
| 215 | + f"Got {type(output)} instead. Please ensure that the denoiser module returns the timestep " |
| 216 | + f"modulated noise output as the first element." |
| 217 | + ) |
| 218 | + |
| 219 | + if self.denoiser_state.timestep_modulated_cache is not None: |
| 220 | + l1_diff = (timestep_modulated_noise - self.denoiser_state.timestep_modulated_cache).abs().mean() |
| 221 | + normalized_l1_diff = l1_diff / self.denoiser_state.timestep_modulated_cache.abs().mean() |
| 222 | + rescaled_l1_diff = self.rescale_fn(normalized_l1_diff) |
| 223 | + self.denoiser_state.accumulated_l1_difference += rescaled_l1_diff |
| 224 | + |
| 225 | + if self.denoiser_state.accumulated_l1_difference >= self.l1_threshold: |
| 226 | + self.denoiser_state.should_skip_blocks = True |
| 227 | + self.denoiser_state.accumulated_l1_difference = 0.0 |
| 228 | + else: |
| 229 | + self.denoiser_state.should_skip_blocks = False |
| 230 | + |
| 231 | + self.denoiser_state.timestep_modulated_cache = timestep_modulated_noise |
| 232 | + return output |
| 233 | + |
| 234 | + |
| 235 | +class DenoiserStateBasedSkipLayerHook(ModelHook): |
| 236 | + _is_stateful = False |
| 237 | + |
| 238 | + def __init__(self, denoiser_state: TeaCacheDenoiserState) -> None: |
| 239 | + self.denoiser_state = denoiser_state |
| 240 | + |
| 241 | + def new_forward(self, module, *args, **kwargs): |
| 242 | + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) |
| 243 | + |
| 244 | + if not self.denoiser_state.should_skip_blocks: |
| 245 | + output = module._old_forward(*args, **kwargs) |
| 246 | + else: |
| 247 | + # Diffusers models either expect one output (hidden_states) or a tuple of two outputs (hidden_states, encoder_hidden_states). |
| 248 | + # Returning a tuple of None values handles both cases. It is okay to do because we are not going to be using these |
| 249 | + # anywhere if self.denoiser_state.should_skip_blocks is True. |
| 250 | + output = (None, None) |
| 251 | + |
| 252 | + return module._diffusers_hook.post_forward(module, output) |
0 commit comments