 import torch

 from ..utils import get_logger
+from .attention import FeedForward, LuminaFeedForward
+from .embeddings import LuminaPatchEmbed, CogVideoXPatchEmbed, CogView3PlusPatchEmbed, TimestepEmbedding, HunyuanDiTAttentionPool, AttentionPooling, MochiAttentionPool, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection


 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -249,16 +251,58 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:


 class LayerwiseUpcastingGranualarity(str, Enum):
+    r"""
+    An enumeration class that defines the granularity of the layerwise upcasting process.
+
+    Granularity can be one of the following:
+    - `DIFFUSERS_MODEL`:
+        Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This casts
+        all layers of the model to the specified storage dtype. It results in the lowest memory usage for
+        storing the model in memory, but may incur a significant loss in quality because layers that perform
+        normalization with learned parameters (e.g., RMSNorm with elementwise affinity) are also cast to the
+        lower dtype, which is known to cause quality issues. This method does not reduce the memory required
+        for the forward pass (which comprises intermediate activations and gradients) of a given modeling
+        component, but it may be useful in cases like lowering the memory footprint of text encoders in a
+        pipeline.
+    - `DIFFUSERS_BLOCK`:
+        TODO???
+    - `DIFFUSERS_LAYER`:
+        Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular
+        than the `DIFFUSERS_MODEL` level but less granular than the `PYTORCH_LAYER` level. This method is
+        applied only to layers that are groups of linear layers, while excluding precision-critical layers
+        like modulation and normalization layers.
+    - `PYTORCH_LAYER`:
+        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most
+        granular level of layerwise upcasting. The memory footprint for inference and training is greatly
+        reduced, while also ensuring, by default, that important operations like normalization with learned
+        parameters remain unaffected by the downcasting/upcasting process. Because not all parameters are
+        cast to lower precision, the memory footprint for storing the model may be slightly higher than with
+        the alternatives. This method causes the highest number of casting operations, which may contribute
+        to a slight increase in overall computation time.
+
+    Note: try to ensure that precision-critical layers like modulation and normalization layers are not cast
+    to lower precision, as this may lead to significant quality loss.
+    """
+
     DIFFUSERS_MODEL = "diffusers_model"
     DIFFUSERS_LAYER = "diffusers_layer"
     PYTORCH_LAYER = "pytorch_layer"

 # fmt: off
+_SUPPORTED_DIFFUSERS_LAYERS = [
+    AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool,
+    CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed,
+    TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection,
+    FeedForward, LuminaFeedForward,
+]
+
 _SUPPORTED_PYTORCH_LAYERS = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
 ]
+
+_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
 # fmt: on


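Reviewer note: the docstring above describes keeping weights in a low-precision storage dtype and upcasting them only around the forward pass. The hook attached by `apply_layerwise_upcasting_hook` is not shown in this diff, so the snippet below is only a minimal, self-contained sketch of that idea using plain PyTorch forward hooks; the function name and exact behavior are assumptions, not the PR's actual implementation.

```python
import torch


def sketch_layerwise_upcasting_hook(
    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
) -> torch.nn.Module:
    # Keep the parameters in the low-precision storage dtype while the module is idle.
    module.to(storage_dtype)

    def upcast_before_forward(mod, args):
        # Upcast to the compute dtype right before the forward pass runs.
        mod.to(compute_dtype)

    def downcast_after_forward(mod, args, output):
        # Downcast back to the storage dtype once the forward pass has finished.
        mod.to(storage_dtype)
        return output

    module.register_forward_pre_hook(upcast_before_forward)
    module.register_forward_hook(downcast_after_forward)
    return module
```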
@@ -291,9 +335,9 @@ def apply_layerwise_upcasting(
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL:
-        return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype)
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER:
-        raise NotImplementedError(f"{LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER} is not yet supported")
+        return _apply_layerwise_upcasting_diffusers_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
     if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER:
         return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)

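Reviewer note: as a quick illustration of how the dispatcher above might be called, here is a hypothetical usage sketch. The toy `torch.nn.Sequential` model, the float8 storage dtype, and the keyword-argument call are all assumptions (the real entry point targets diffusers models and its parameter order may differ), and it assumes `apply_layerwise_upcasting` and `LayerwiseUpcastingGranualarity` are importable from this module.

```python
import torch

# Hypothetical toy model; in practice this would be a diffusers model
# (e.g. a transformer or a text encoder) rather than a plain Sequential.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.GELU(),
    torch.nn.Linear(128, 64),
)

# Assumed keyword-argument call into the dispatcher shown above. With
# PYTORCH_LAYER granularity, only supported primitive layers (Linear/Conv)
# whose names do not match the default skip patterns get the upcasting hook.
model = apply_layerwise_upcasting(
    module=model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    granularity=LayerwiseUpcastingGranualarity.PYTORCH_LAYER,
)
```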
@@ -302,16 +346,29 @@ def _apply_layerwise_upcasting_diffusers_model(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = [],
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     from .modeling_utils import ModelMixin

+    if not isinstance(module, ModelMixin):
+        raise ValueError("The input module must be an instance of ModelMixin")
+
+    logger.debug(f"Applying layerwise upcasting to model \"{module.__class__.__name__}\"")
+    apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype)
+    return module
+
+
+def _apply_layerwise_upcasting_diffusers_layer(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
+    skip_modules_classes: List[Type[torch.nn.Module]] = [],
+) -> torch.nn.Module:
     for name, submodule in module.named_modules():
         if (
             any(re.search(pattern, name) for pattern in skip_modules_pattern)
             or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, ModelMixin)
+            or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS))
         ):
             logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"")
             continue
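Reviewer note: the skip logic in the loop above relies on `re.search` against fully qualified module names, so the default pattern list excludes any submodule whose name contains `pos_embed`, `patch_embed`, or `norm`. A small standalone check with hypothetical module names makes that behaviour easy to verify:

```python
import re

_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]

# Hypothetical module names, only for illustrating which ones are skipped.
names = [
    "transformer_blocks.0.ff.net.0.proj",
    "transformer_blocks.0.norm1",
    "patch_embed.proj",
    "proj_out",
]

for name in names:
    skipped = any(re.search(pattern, name) for pattern in _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN)
    print(f"{name}: {'skipped' if skipped else 'upcasting hook applied'}")
```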
@@ -324,7 +381,7 @@ def _apply_layerwise_upcasting_pytorch_layer(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = [],
+    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     for name, submodule in module.named_modules():
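Reviewer note: the diff is truncated here. Judging from the `_apply_layerwise_upcasting_diffusers_layer` variant above, the rest of `_apply_layerwise_upcasting_pytorch_layer` presumably filters on `_SUPPORTED_PYTORCH_LAYERS` and attaches the hook. The following is an assumed sketch, not the PR's actual code, and it reuses the `sketch_layerwise_upcasting_hook` from the earlier snippet in place of the PR's real hook function.

```python
import re

import torch


def sketch_apply_layerwise_upcasting_pytorch_layer(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern=("pos_embed", "patch_embed", "norm"),
    skip_modules_classes=(),
) -> torch.nn.Module:
    # Assumed to mirror the diffusers-layer variant shown earlier in this diff:
    # skip name-pattern and class matches, and only hook supported primitives.
    supported_layers = (
        torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
        torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
        torch.nn.Linear,
    )
    for name, submodule in module.named_modules():
        if (
            any(re.search(pattern, name) for pattern in skip_modules_pattern)
            or any(isinstance(submodule, cls) for cls in skip_modules_classes)
            or not isinstance(submodule, supported_layers)
        ):
            continue
        # The PR presumably calls apply_layerwise_upcasting_hook(submodule, ...)
        # here; the sketch hook from the earlier snippet stands in for it.
        sketch_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
    return module
```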