 
 import functools
 import re
-from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Tuple, List, Type
+from typing import Any, Dict, List, Tuple, Type
 
 import torch
 
 from ..utils import get_logger
 from .attention import FeedForward, LuminaFeedForward
-from .embeddings import LuminaPatchEmbed, CogVideoXPatchEmbed, CogView3PlusPatchEmbed, TimestepEmbedding, HunyuanDiTAttentionPool, AttentionPooling, MochiAttentionPool, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection
+from .embeddings import (
+    AttentionPooling,
+    CogVideoXPatchEmbed,
+    CogView3PlusPatchEmbed,
+    GLIGENTextBoundingboxProjection,
+    HunyuanDiTAttentionPool,
+    LuminaPatchEmbed,
+    MochiAttentionPool,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+)
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -38,6 +47,7 @@ class ModelHook:
     def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -47,6 +57,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
@@ -63,6 +74,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass has been executed just before this event.
@@ -76,6 +88,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
+
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
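
As an aside (not part of this diff), a minimal sketch of how the four entry points of the `ModelHook` interface fit together; `LoggingHook` and everything inside it are hypothetical:

import torch

class LoggingHook(ModelHook):
    # Illustrative hook: records the shape of the first tensor argument of every forward call.
    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        self.seen_shapes = []
        return module

    def pre_forward(self, module, *args, **kwargs):
        if args and isinstance(args[0], torch.Tensor):
            self.seen_shapes.append(tuple(args[0].shape))
        return args, kwargs

    def post_forward(self, module, output):
        return output

    def detach_hook(self, module):
        self.seen_shapes = []
        return module
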
@@ -112,10 +125,10 @@ def detach_hook(self, module):
 
 class LayerwiseUpcastingHook(ModelHook):
     r"""
-    A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass
-    and cast the module back to the original dtype after the forward pass. This is useful when a model is
-    loaded/stored in a lower precision dtype but performs computation in a higher precision dtype. This
-    process may lead to quality loss in the output, but can significantly reduce the memory footprint.
+    A hook that casts the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and
+    casts the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored
+    in a lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality
+    loss in the output, but can significantly reduce the memory footprint.
     """
 
     def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
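
A quick sketch of what the two dtypes mean, based on the docstring above; the float8/bfloat16 pairing is only an assumption:

import torch

# Hypothetical pairing: weights kept in float8 for storage, upcast to bfloat16 to compute.
hook = LayerwiseUpcastingHook(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
# pre_forward:  module (and tensor inputs) -> compute_dtype (bfloat16)
# post_forward: module -> storage_dtype (float8_e4m3fn)
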
@@ -144,10 +157,14 @@ def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool =
     r"""
     Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
     this behavior and restore the original `forward` method, use `remove_hook_from_module`.
+
     <Tip warning={true}>
+
     If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
     together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
+
     </Tip>
+
     Args:
         module (`torch.nn.Module`):
             The module to attach a hook to.
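
A hedged usage sketch for `add_hook_to_module`, reusing the hypothetical `LoggingHook` from the earlier sketch and assuming a PyTorch build with float8 support:

import torch

linear = torch.nn.Linear(64, 64).to(torch.float8_e4m3fn)
add_hook_to_module(linear, LayerwiseUpcastingHook(torch.float8_e4m3fn, torch.bfloat16))

# Chaining: append=True wraps the existing hook and the new one in a SequentialHook.
add_hook_to_module(linear, LoggingHook(), append=True)

output = linear(torch.randn(1, 64, dtype=torch.bfloat16))
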
@@ -198,6 +215,7 @@ def new_forward(module, *args, **kwargs):
 def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
     """
     Removes any hook attached to a module via `add_hook_to_module`.
+
     Args:
         module (`torch.nn.Module`):
             The module to detach the hook from.
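
Continuing the sketch above, detaching is a single call; `recurse=True` also walks submodules:

remove_hook_from_module(linear, recurse=True)  # restores the original `forward` and runs `detach_hook`
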
@@ -231,10 +249,11 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
 def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
     r"""
     Aligns the dtype of a tensor or a list of tensors to a given dtype.
+
     Args:
         input (`Any`):
-            The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither
-            of these types, it will be returned as is.
+            The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these
+            types, it will be returned as is.
         dtype (`torch.dtype`):
             The dtype to align the tensor(s) to.
     Returns:
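
Based on the docstring above (the body of the function is outside this hunk), the expected behavior is roughly:

import torch

x = torch.randn(2, 2, dtype=torch.float32)
align_maybe_tensor_dtype(x, torch.bfloat16)                 # tensor cast to bfloat16
align_maybe_tensor_dtype([x, x], torch.bfloat16)            # list with every tensor cast
align_maybe_tensor_dtype({"sample": x}, torch.bfloat16)     # dict with every value cast
align_maybe_tensor_dtype("not a tensor", torch.bfloat16)    # unsupported type, returned as is
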
@@ -256,38 +275,38 @@ class LayerwiseUpcastingGranualarity(str, Enum):
 
     Granularity can be one of the following:
     - `DIFFUSERS_MODEL`:
-        Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This
-        will cast all the layers of model to the specified storage dtype. This results in the lowest
-        memory usage for storing the model in memory, but may incur significant loss in quality because
-        layers that perform normalization with learned parameters (e.g., RMSNorm with elementwise affinity)
-        are cast to a lower dtype, but this is known to cause quality issues. This method will not reduce the
-        memory required for the forward pass (which comprises of intermediate activations and gradients) of a
-        given modeling component, but may be useful in cases like lowering the memory footprint of text
-        encoders in a pipeline.
+        Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This will cast all
+        the layers of the model to the specified storage dtype. This results in the lowest memory usage for storing
+        the model in memory, but may incur significant quality loss because layers that perform normalization with
+        learned parameters (e.g., RMSNorm with elementwise affine parameters) are also cast to the lower dtype, which
+        is known to cause quality issues. This method will not reduce the memory required for the forward pass (which
+        comprises intermediate activations and gradients) of a given modeling component, but may be useful in cases
+        like lowering the memory footprint of text encoders in a pipeline.
     - `DIFFUSERS_BLOCK`:
         TODO???
     - `DIFFUSERS_LAYER`:
-        Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular
-        than the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is
-        applied to only those layers that are a group of linear layers, while excluding precision-critical
-        layers like modulation and normalization layers.
+        Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular than the
+        `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is applied only to
+        layers that are groups of linear layers, while excluding precision-critical layers like modulation and
+        normalization layers.
     - `PYTORCH_LAYER`:
-        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most
-        granular level of layerwise upcasting. The memory footprint for inference and training is greatly
-        reduced, while also ensuring important operations like normalization with learned parameters remain
-        unaffected from the downcasting/upcasting process, by default. As not all parameters are casted to
-        lower precision, the memory footprint for storing the model may be slightly higher than the alternatives.
-        This method causes the highest number of casting operations, which may contribute to a slight increase
-        in the overall computation time.
-
-    Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted
-    to lower precision, as this may lead to significant quality loss.
+        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular
+        level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while also
+        ensuring that important operations like normalization with learned parameters remain unaffected by the
+        downcasting/upcasting process, by default. As not all parameters are cast to lower precision, the memory
+        footprint for storing the model may be slightly higher than the alternatives. This method causes the highest
+        number of casting operations, which may contribute to a slight increase in the overall computation time.
+
+    Note: try to ensure that precision-critical layers like modulation and normalization layers are not cast to lower
+    precision, as this may lead to significant quality loss.
     """
-
+
     DIFFUSERS_MODEL = "diffusers_model"
     DIFFUSERS_LAYER = "diffusers_layer"
     PYTORCH_LAYER = "pytorch_layer"
 
+
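
Because the enum also subclasses `str`, members compare equal to their string values, so either form can be passed around:

granularity = LayerwiseUpcastingGranualarity("pytorch_layer")
granularity is LayerwiseUpcastingGranualarity.PYTORCH_LAYER   # True
granularity == "pytorch_layer"                                # also True
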
 # fmt: off
 _SUPPORTED_DIFFUSERS_LAYERS = [
     AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool,
@@ -306,18 +325,20 @@ class LayerwiseUpcastingGranualarity(str, Enum):
 # fmt: on
 
 
-def apply_layerwise_upcasting_hook(module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> torch.nn.Module:
+def apply_layerwise_upcasting_hook(
+    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
+) -> torch.nn.Module:
     r"""
     Applies a `LayerwiseUpcastingHook` to a given module.
-
+
     Args:
         module (`torch.nn.Module`):
             The module to attach the hook to.
         storage_dtype (`torch.dtype`):
             The dtype to cast the module to before the forward pass.
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass.
-
+
     Returns:
         `torch.nn.Module`:
             The same module, with the hook attached (the module is modified in place, so the result can be discarded).
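
A small sketch of applying the hook to a single layer; the `FeedForward` constructor arguments and the float8/bfloat16 pairing are illustrative assumptions:

import torch

ff = FeedForward(dim=64, dim_out=64).to(torch.float8_e4m3fn)   # module stored in the low-precision dtype
apply_layerwise_upcasting_hook(ff, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
output = ff(torch.randn(1, 16, 64, dtype=torch.bfloat16))      # upcast to bfloat16 only for the forward pass
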
@@ -337,9 +358,13 @@ def apply_layerwise_upcasting(
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL:
         return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype)
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER:
-        return _apply_layerwise_upcasting_diffusers_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        return _apply_layerwise_upcasting_diffusers_layer(
+            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
+        )
     if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER:
-        return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        return _apply_layerwise_upcasting_pytorch_layer(
+            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
+        )
 
 
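A hedged call sketch for the dispatcher; keyword arguments are used because the parameter defaults are outside this hunk, and `transformer` plus the skip patterns are placeholders:

import torch

apply_layerwise_upcasting(
    transformer,                                   # any diffusers ModelMixin instance
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    granularity=LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER,
    skip_modules_pattern=["patch_embed", "norm"],  # regex patterns matched against submodule names
    skip_modules_classes=[],
)
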
 def _apply_layerwise_upcasting_diffusers_model(
@@ -352,7 +377,7 @@ def _apply_layerwise_upcasting_diffusers_model(
     if not isinstance(module, ModelMixin):
         raise ValueError("The input module must be an instance of ModelMixin")
 
-    logger.debug(f"Applying layerwise upcasting to model \"{module.__class__.__name__}\"")
+    logger.debug(f'Applying layerwise upcasting to model "{module.__class__.__name__}"')
     apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype)
     return module
 
@@ -370,9 +395,9 @@ def _apply_layerwise_upcasting_diffusers_layer(
             or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
             or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS))
         ):
-            logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"")
+            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
             continue
-        logger.debug(f"Applying layerwise upcasting to layer \"{name}\"")
+        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
         apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
     return module
 
@@ -390,8 +415,8 @@ def _apply_layerwise_upcasting_pytorch_layer(
             or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
             or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
         ):
-            logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"")
+            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
             continue
-        logger.debug(f"Applying layerwise upcasting to layer \"{name}\"")
+        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
         apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
     return module
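
Putting the pieces together, a rough end-to-end sketch with a real diffusers model; whether parameters are moved to the storage dtype by the hook itself or by the caller is not visible in this diff, so the workflow below is an assumption:

import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    granularity=LayerwiseUpcastingGranualarity.PYTORCH_LAYER,
    skip_modules_pattern=[],
    skip_modules_classes=[],
)
# Hooked primitive layers (e.g. Linear) should now run in bfloat16 while being kept in float8
# between forward passes, per the LayerwiseUpcastingHook docstring above.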