
Commit 42046c0

update
1 parent 36b0c37 commit 42046c0

File tree: 1 file changed (+63, -6 lines)


src/diffusers/models/hooks.py

Lines changed: 63 additions & 6 deletions
@@ -21,6 +21,8 @@
 import torch
 
 from ..utils import get_logger
+from .attention import FeedForward, LuminaFeedForward
+from .embeddings import LuminaPatchEmbed, CogVideoXPatchEmbed, CogView3PlusPatchEmbed, TimestepEmbedding, HunyuanDiTAttentionPool, AttentionPooling, MochiAttentionPool, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -249,16 +251,58 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
 
 
 class LayerwiseUpcastingGranualarity(str, Enum):
+    r"""
+    An enumeration class that defines the granularity of the layerwise upcasting process.
+
+    Granularity can be one of the following:
+        - `DIFFUSERS_MODEL`:
+            Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This
+            casts all layers of the model to the specified storage dtype. This results in the lowest memory
+            usage for storing the model, but may incur a significant loss in quality because layers that
+            perform normalization with learned parameters (e.g., RMSNorm with elementwise affinity) are also
+            cast to the lower dtype, which is known to cause quality issues. This method does not reduce the
+            memory required for the forward pass (which comprises intermediate activations and gradients) of
+            a given modeling component, but may be useful in cases like lowering the memory footprint of text
+            encoders in a pipeline.
+        - `DIFFUSERS_BLOCK`:
+            TODO???
+        - `DIFFUSERS_LAYER`:
+            Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more
+            granular than the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This
+            method is applied only to layers that are groups of linear layers, while excluding
+            precision-critical layers like modulation and normalization layers.
+        - `PYTORCH_LAYER`:
+            Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most
+            granular level of layerwise upcasting. The memory footprint for inference and training is greatly
+            reduced, while important operations like normalization with learned parameters remain unaffected
+            by the downcasting/upcasting process by default. As not all parameters are cast to lower
+            precision, the memory footprint for storing the model may be slightly higher than with the
+            alternatives. This method causes the highest number of casting operations, which may contribute
+            to a slight increase in the overall computation time.
+
+    Note: take care that precision-critical layers like modulation and normalization layers are not cast to
+    lower precision, as this may lead to significant quality loss.
+    """
+
     DIFFUSERS_MODEL = "diffusers_model"
     DIFFUSERS_LAYER = "diffusers_layer"
     PYTORCH_LAYER = "pytorch_layer"
 
 # fmt: off
+_SUPPORTED_DIFFUSERS_LAYERS = [
+    AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool,
+    CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed,
+    TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection,
+    FeedForward, LuminaFeedForward,
+]
+
 _SUPPORTED_PYTORCH_LAYERS = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
 ]
+
+_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
 # fmt: on
 
 
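For reference, a minimal usage sketch of the entry point with the granularity options above. The parameter names are taken from the helpers visible in this diff; the keyword-argument call order, the `diffusers.models.hooks` import path, and the toy module are assumptions, and `torch.float8_e4m3fn` is just one possible storage dtype:

import torch
from diffusers.models.hooks import apply_layerwise_upcasting, LayerwiseUpcastingGranualarity

# Hypothetical toy module; in practice this would be a diffusers transformer or UNet.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))

# Store weights in a low-precision dtype and upcast them to compute_dtype for the forward pass.
model = apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.float32,
    granularity=LayerwiseUpcastingGranualarity.PYTORCH_LAYER,
)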

@@ -291,9 +335,9 @@ def apply_layerwise_upcasting(
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL:
-        return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype)
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER:
-        raise NotImplementedError(f"{LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER} is not yet supported")
+        return _apply_layerwise_upcasting_diffusers_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
     if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER:
         return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
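With this change, the `DIFFUSERS_LAYER` granularity is dispatched instead of raising `NotImplementedError`, and the skip arguments are forwarded only for the layer-level granularities (the model-level path no longer accepts them). Continuing the sketch above, a hedged example with an extra, hypothetical skip pattern:

# Hypothetical: also skip any submodule whose name mentions "proj_out", on top of the defaults.
model = apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.float32,
    granularity=LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER,
    skip_modules_pattern=["pos_embed", "patch_embed", "norm", "proj_out"],
)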

@@ -302,16 +346,29 @@ def _apply_layerwise_upcasting_diffusers_model(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = [],
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     from .modeling_utils import ModelMixin
 
+    if not isinstance(module, ModelMixin):
+        raise ValueError("The input module must be an instance of ModelMixin")
+
+    logger.debug(f"Applying layerwise upcasting to model \"{module.__class__.__name__}\"")
+    apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype)
+    return module
+
+
+def _apply_layerwise_upcasting_diffusers_layer(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
+    skip_modules_classes: List[Type[torch.nn.Module]] = [],
+) -> torch.nn.Module:
     for name, submodule in module.named_modules():
         if (
             any(re.search(pattern, name) for pattern in skip_modules_pattern)
             or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, ModelMixin)
+            or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS))
         ):
             logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"")
             continue
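The skip logic above can be illustrated in isolation. This is a hypothetical, self-contained sketch of the same name/class filtering, using `torch.nn.Linear` as a stand-in for the supported diffusers layer types:

import re
import torch

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = torch.nn.Linear(8, 8)    # eligible for upcasting
        self.norm = torch.nn.LayerNorm(8)  # skipped: name matches "norm"

skip_modules_pattern = ["pos_embed", "patch_embed", "norm"]
skip_modules_classes = []
supported_layers = (torch.nn.Linear,)  # stand-in for tuple(_SUPPORTED_DIFFUSERS_LAYERS)

for name, submodule in TinyBlock().named_modules():
    skip = (
        any(re.search(pattern, name) for pattern in skip_modules_pattern)
        or any(isinstance(submodule, cls) for cls in skip_modules_classes)
        or not isinstance(submodule, supported_layers)
    )
    print(f"{name or '<root>'}: {'skip' if skip else 'upcast'}")
# Prints: <root>: skip, ff: upcast, norm: skip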
@@ -324,7 +381,7 @@ def _apply_layerwise_upcasting_pytorch_layer(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
     compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = [],
+    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
     skip_modules_classes: List[Type[torch.nn.Module]] = [],
 ) -> torch.nn.Module:
     for name, submodule in module.named_modules():
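Both the diffusers-layer and PyTorch-layer helpers now default to `_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN`, so positional/patch embeddings and normalization layers are excluded unless a caller overrides the pattern. A small check of what the default regexes match; the module names below are made up for illustration:

import re

default_pattern = ["pos_embed", "patch_embed", "norm"]
for name in ["patch_embed.proj", "transformer_blocks.0.norm1", "transformer_blocks.0.ff.net.0.proj"]:
    skipped = any(re.search(p, name) for p in default_pattern)
    print(name, "-> skipped" if skipped else "-> upcast")
# patch_embed.proj and transformer_blocks.0.norm1 are skipped; the feed-forward projection is not.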
