
Commit 341fbfc: update mixin
1 parent 8975bbf

26 files changed (+126, -9 lines)

src/diffusers/models/autoencoders/autoencoder_asym_kl.py

Lines changed: 2 additions & 0 deletions
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

+    _always_upcast_modules = ["MaskConditionDecoder"]
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/autoencoders/vq_model.py

Lines changed: 2 additions & 0 deletions
@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
             Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
     """

+    _always_upcast_modules = ["VectorQuantizer"]
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/layerwise_upcasting_utils.py

Lines changed: 30 additions & 9 deletions
@@ -45,6 +45,8 @@ class LayerwiseUpcastingHook(ModelHook):
     in the output, but can significantly reduce the memory footprint.
     """

+    _is_stateful = False
+
     def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
         self.storage_dtype = storage_dtype
         self.compute_dtype = compute_dtype
@@ -56,8 +58,8 @@ def init_hook(self, module: torch.nn.Module):
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         module.to(dtype=self.compute_dtype)
         # How do we account for LongTensor, BoolTensor, etc.?
-        # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
-        # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
+        # args = tuple(_align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
+        # kwargs = {k: _align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
         return args, kwargs

     def post_forward(self, module: torch.nn.Module, output):
@@ -105,7 +107,7 @@ class LayerwiseUpcastingGranularity(str, Enum):
     torch.nn.Linear,
 ]

-_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
+_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
 # fmt: on

@@ -114,9 +116,27 @@ def apply_layerwise_upcasting(
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
     granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER,
-    skip_modules_pattern: List[str] = [],
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
+    r"""
+    Applies layerwise upcasting to a given module. The module is expected to be a Diffusers `ModelMixin`, but it can
+    be any `nn.Module` that uses diffusers layers or PyTorch primitives.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to attach the hook to.
+        storage_dtype (`torch.dtype`):
+            The dtype to which the module is cast for storage, outside the forward pass.
+        compute_dtype (`torch.dtype`):
+            The dtype to which the module is cast during the forward pass for computation.
+        granularity (`LayerwiseUpcastingGranularity`, *optional*, defaults to `LayerwiseUpcastingGranularity.PYTORCH_LAYER`):
+            The granularity of the layerwise upcasting process.
+        skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
+            A list of patterns matched against module names; matching modules are skipped during layerwise upcasting.
+        skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
+            A list of module classes to skip during layerwise upcasting.
+    """
     if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER:
         return _apply_layerwise_upcasting_diffusers_layer(
             module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
@@ -153,7 +173,7 @@ def _apply_layerwise_upcasting_diffusers_layer(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     for name, submodule in module.named_modules():
@@ -173,7 +193,7 @@ def _apply_layerwise_upcasting_pytorch_layer(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     for name, submodule in module.named_modules():
@@ -189,7 +209,7 @@ def _apply_layerwise_upcasting_pytorch_layer(
     return module


-def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
+def _align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
     r"""
     Aligns the dtype of a tensor or a list of tensors to a given dtype.
@@ -199,14 +219,15 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
             types, it will be returned as is.
         dtype (`torch.dtype`):
             The dtype to align the tensor(s) to.
+
     Returns:
         `Any`:
             The tensor or list of tensors aligned to the given dtype.
     """
     if isinstance(input, torch.Tensor):
         return input.to(dtype=dtype)
     if isinstance(input, (list, tuple)):
-        return [align_maybe_tensor_dtype(t, dtype) for t in input]
+        return [_align_maybe_tensor_dtype(t, dtype) for t in input]
     if isinstance(input, dict):
-        return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
+        return {k: _align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
     return input
src/diffusers/models/modeling_utils.py

Lines changed: 63 additions & 0 deletions
@@ -56,6 +56,7 @@
     load_or_create_model_card,
     populate_model_card,
 )
+from .layerwise_upcasting_utils import LayerwiseUpcastingGranularity, apply_layerwise_upcasting
 from .model_loading_utils import (
     _determine_device_map,
     _fetch_index_file,
@@ -150,6 +151,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
     _keep_in_fp32_modules = None
+    _always_upcast_modules = None

     def __init__(self):
         super().__init__()
@@ -314,6 +316,67 @@ def disable_xformers_memory_efficient_attention(self) -> None:
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def enable_layerwise_upcasting(
+        self,
+        storage_dtype: torch.dtype = torch.float8_e4m3fn,
+        compute_dtype: Optional[torch.dtype] = None,
+        granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER,
+    ) -> None:
+        r"""
+        Activates layerwise upcasting for the current model.
+
+        Layerwise upcasting is a technique that casts the model weights to a lower-precision dtype for storage but
+        upcasts them on-the-fly to a higher-precision dtype for computation. This process can significantly reduce
+        the memory footprint of the model weights, but may lead to some quality degradation in the outputs. Most
+        degradations are negligible, mostly stemming from weight casting in normalization and modulation layers.
+
+        By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding,
+        positional embedding and normalization layers. This is because these layers are most likely
+        precision-critical for quality. If you wish to change this behavior, you can set the `_always_upcast_modules`
+        attribute to `None`, or call [`~apply_layerwise_upcasting`] with custom arguments.
+
+        Example:
+            Using [`~models.ModelMixin.enable_layerwise_upcasting`]:
+
+            ```python
+            >>> import torch
+            >>> from diffusers import CogVideoXTransformer3DModel, apply_layerwise_upcasting
+
+            >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+            ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+            ... )
+
+            >>> # Enable layerwise upcasting via the model, which ignores certain modules by default
+            >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+            >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function
+            >>> apply_layerwise_upcasting(
+            ...     transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm.*"]
+            ... )
+            ```
+
+        Args:
+            storage_dtype (`torch.dtype`):
+                The dtype to which the model should be cast for storage.
+            compute_dtype (`torch.dtype`, *optional*):
+                The dtype to which the model weights should be cast during the forward pass. If not provided, the
+                model's current dtype is used.
+            granularity (`LayerwiseUpcastingGranularity`, defaults to `"pytorch_layer"`):
+                The granularity of the layerwise upcasting process. Read the documentation of
+                [`~LayerwiseUpcastingGranularity`] for more information.
+        """
+
+        skip_modules_pattern = []
+        if self._keep_in_fp32_modules is not None:
+            skip_modules_pattern.extend(self._keep_in_fp32_modules)
+        if self._always_upcast_modules is not None:
+            skip_modules_pattern.extend(self._always_upcast_modules)
+        skip_modules_pattern = list(set(skip_modules_pattern))
+
+        if compute_dtype is None:
+            logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using the model's current dtype.")
+            compute_dtype = self.dtype
+
+        apply_layerwise_upcasting(self, storage_dtype, compute_dtype, granularity, skip_modules_pattern)
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
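
Note: a rough way to see the effect of the new method on weight memory. This is a minimal sketch that assumes the hooks cast eligible weights to `storage_dtype` as soon as they are attached (as the docstring above describes); the repo id and dtypes are the ones from the docstring example.

```python
import torch
from diffusers import CogVideoXTransformer3DModel

def weight_nbytes(module: torch.nn.Module) -> int:
    # Bytes occupied by parameters at rest.
    return sum(p.numel() * p.element_size() for p in module.parameters())

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
print(f"bf16 storage: {weight_nbytes(transformer) / 1e9:.2f} GB")

transformer.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
# Skipped modules (patch_embed, norm.*) stay in bf16, so the saving is a bit less than 2x.
print(f"fp8 storage:  {weight_nbytes(transformer) / 1e9:.2f} GB")
```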

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -275,6 +275,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
     """

     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
+    _always_upcast_modules = ["pos_embed", "norm.*"]
     _supports_gradient_checkpointing = True

     @register_to_config

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 1 addition & 0 deletions
@@ -209,6 +209,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Scaling factor to apply in 3D positional embeddings across temporal dimensions.
     """

+    _always_upcast_modules = ["patch_embed", "norm.*"]
     _supports_gradient_checkpointing = True

     @register_to_config

src/diffusers/models/transformers/dit_transformer_2d.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
             A small constant added to the denominator in normalization layers to prevent division by zero.
     """

+    _always_upcast_modules = ["pos_embed", "norm.*"]
     _supports_gradient_checkpointing = True

     @register_to_config

src/diffusers/models/transformers/hunyuan_transformer_2d.py

Lines changed: 2 additions & 0 deletions
@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
     """

+    _always_upcast_modules = ["pos_embed", "norm.*", "pooler"]
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 2 additions & 0 deletions
@@ -65,6 +65,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             The number of frames in the video-like data.
     """

+    _always_upcast_modules = ["pos_embed", "norm.*"]
+
     @register_to_config
     def __init__(
         self,

src/diffusers/models/transformers/lumina_nextdit2d.py

Lines changed: 2 additions & 0 deletions
@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
             overall scale of the model's operations.
     """

+    _always_upcast_modules = ["patch_embedder", "norm.*", "ffn_norm.*"]
+
     @register_to_config
     def __init__(
         self,
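
Note: the `_always_upcast_modules` patterns added across the models above (e.g. `"pos_embed"`, `"norm.*"`) are merged into the skip list by `enable_layerwise_upcasting`. Below is a hypothetical helper, not part of diffusers, for previewing which submodules such patterns would skip; the regex-based matching is an assumption for illustration, mirroring the `named_modules()` loops in the `_apply_*` functions.

```python
import re
from typing import List

import torch.nn as nn

def preview_skipped_modules(model: nn.Module, skip_modules_pattern: List[str]) -> List[str]:
    # List submodule names matching any skip pattern; these would be left
    # untouched (kept in higher precision) by layerwise upcasting.
    return [
        name
        for name, _ in model.named_modules()
        if any(re.search(pattern, name) for pattern in skip_modules_pattern)
    ]

# e.g. preview_skipped_modules(transformer, ["patch_embed", "norm.*"])
```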

0 commit comments
