
Commit 7dc739b

Commit message: make style
Parent: 42046c0

File tree: 1 file changed (+67, -42 lines)


src/diffusers/models/hooks.py

Lines changed: 67 additions & 42 deletions
@@ -14,15 +14,24 @@
 
 import functools
 import re
-from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Dict, Tuple, List, Type
+from typing import Any, Dict, List, Tuple, Type
 
 import torch
 
 from ..utils import get_logger
 from .attention import FeedForward, LuminaFeedForward
-from .embeddings import LuminaPatchEmbed, CogVideoXPatchEmbed, CogView3PlusPatchEmbed, TimestepEmbedding, HunyuanDiTAttentionPool, AttentionPooling, MochiAttentionPool, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection
+from .embeddings import (
+    AttentionPooling,
+    CogVideoXPatchEmbed,
+    CogView3PlusPatchEmbed,
+    GLIGENTextBoundingboxProjection,
+    HunyuanDiTAttentionPool,
+    LuminaPatchEmbed,
+    MochiAttentionPool,
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+)
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -38,6 +47,7 @@ class ModelHook:
     def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -47,6 +57,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
@@ -63,6 +74,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass been executed just before this event.
@@ -76,6 +88,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
+
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
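
Taken together, the four methods above define the full ModelHook lifecycle: init_hook when the hook is attached, pre_forward/post_forward around every call, and detach_hook on removal. A minimal sketch of a custom hook written against this interface (the DebugShapeHook name and its prints are illustrative, not part of this commit):

import torch

class DebugShapeHook(ModelHook):
    # Illustrative hook: reports tensor input/output shapes around each forward call.
    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        print(f"attached to {module.__class__.__name__}")
        return module

    def pre_forward(self, module, *args, **kwargs):
        shapes = [tuple(a.shape) for a in args if torch.is_tensor(a)]
        print(f"input shapes: {shapes}")
        return args, kwargs  # must return (args, kwargs) for the wrapped forward

    def post_forward(self, module, output):
        if torch.is_tensor(output):
            print(f"output shape: {tuple(output.shape)}")
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        print("detached")
        return module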
@@ -112,10 +125,10 @@ def detach_hook(self, module):
 
 class LayerwiseUpcastingHook(ModelHook):
     r"""
-    A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass
-    and cast the module back to the original dtype after the forward pass. This is useful when a model is
-    loaded/stored in a lower precision dtype but performs computation in a higher precision dtype. This
-    process may lead to quality loss in the output, but can significantly reduce the memory footprint.
+    A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast
+    the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a
+    lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss
+    in the output, but can significantly reduce the memory footprint.
     """
 
     def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
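
The hook's pre_forward/post_forward bodies fall outside this hunk; purely as an illustration of the cast-then-restore cycle the docstring describes (not necessarily the exact implementation in this file), they could look roughly like:

    # Illustration only of the described behaviour, not the actual method bodies in this commit.
    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        # Keep parameters in the low-precision storage dtype while the module is idle.
        return module.to(dtype=self.storage_dtype)

    def pre_forward(self, module, *args, **kwargs):
        # Upcast weights and floating-point inputs just before computation.
        module.to(dtype=self.compute_dtype)
        args = tuple(
            a.to(self.compute_dtype) if torch.is_tensor(a) and a.is_floating_point() else a for a in args
        )
        kwargs = {
            k: (v.to(self.compute_dtype) if torch.is_tensor(v) and v.is_floating_point() else v)
            for k, v in kwargs.items()
        }
        return args, kwargs

    def post_forward(self, module, output):
        # Cast the module back down so it is stored in the low-precision dtype again.
        module.to(dtype=self.storage_dtype)
        return output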
@@ -144,10 +157,14 @@ def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool =
     r"""
     Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
     this behavior and restore the original `forward` method, use `remove_hook_from_module`.
+
     <Tip warning={true}>
+
     If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
     together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
+
     </Tip>
+
     Args:
         module (`torch.nn.Module`):
             The module to attach a hook to.
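
A short usage sketch of the attach/chain/detach flow described in this docstring (the nn.Linear module is a placeholder, the DebugShapeHook comes from the sketch above, and fp8 storage assumes a recent PyTorch):

import torch

linear = torch.nn.Linear(8, 8)

# Attaching a hook rewrites linear.forward in place.
add_hook_to_module(linear, LayerwiseUpcastingHook(torch.float8_e4m3fn, torch.bfloat16))
# Passing append=True chains a second hook through SequentialHook instead of replacing the first.
add_hook_to_module(linear, DebugShapeHook(), append=True)

remove_hook_from_module(linear)  # restores the original forward method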
@@ -198,6 +215,7 @@ def new_forward(module, *args, **kwargs):
 def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
     """
     Removes any hook attached to a module via `add_hook_to_module`.
+
     Args:
         module (`torch.nn.Module`):
             The module to attach a hook to.
@@ -231,10 +249,11 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
 def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
     r"""
     Aligns the dtype of a tensor or a list of tensors to a given dtype.
+
     Args:
         input (`Any`):
-            The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither
-            of these types, it will be returned as is.
+            The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these
+            types, it will be returned as is.
         dtype (`torch.dtype`):
             The dtype to align the tensor(s) to.
     Returns:
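
The function body sits outside this hunk; the docstring implies behaviour along these lines (an illustration, not the actual implementation in this file):

    # Illustration of the documented behaviour: tensors, lists/tuples, and dicts are cast; anything else passes through.
    if torch.is_tensor(input):
        return input.to(dtype=dtype)
    if isinstance(input, (list, tuple)):
        return type(input)(align_maybe_tensor_dtype(x, dtype) for x in input)
    if isinstance(input, dict):
        return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
    return input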
@@ -256,38 +275,38 @@ class LayerwiseUpcastingGranualarity(str, Enum):
 
     Granularity can be one of the following:
     - `DIFFUSERS_MODEL`:
-        Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This
-        will cast all the layers of model to the specified storage dtype. This results in the lowest
-        memory usage for storing the model in memory, but may incur significant loss in quality because
-        layers that perform normalization with learned parameters (e.g., RMSNorm with elementwise affinity)
-        are cast to a lower dtype, but this is known to cause quality issues. This method will not reduce the
-        memory required for the forward pass (which comprises of intermediate activations and gradients) of a
-        given modeling component, but may be useful in cases like lowering the memory footprint of text
-        encoders in a pipeline.
+        Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This will cast all
+        the layers of model to the specified storage dtype. This results in the lowest memory usage for storing the
+        model in memory, but may incur significant loss in quality because layers that perform normalization with
+        learned parameters (e.g., RMSNorm with elementwise affinity) are cast to a lower dtype, but this is known
+        to cause quality issues. This method will not reduce the memory required for the forward pass (which
+        comprises of intermediate activations and gradients) of a given modeling component, but may be useful in
+        cases like lowering the memory footprint of text encoders in a pipeline.
     - `DIFFUSERS_BLOCK`:
         TODO???
     - `DIFFUSERS_LAYER`:
-        Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular
-        than the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is
-        applied to only those layers that are a group of linear layers, while excluding precision-critical
-        layers like modulation and normalization layers.
+        Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular than
+        the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is applied to
+        only those layers that are a group of linear layers, while excluding precision-critical layers like
+        modulation and normalization layers.
     - `PYTORCH_LAYER`:
-        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most
-        granular level of layerwise upcasting. The memory footprint for inference and training is greatly
-        reduced, while also ensuring important operations like normalization with learned parameters remain
-        unaffected from the downcasting/upcasting process, by default. As not all parameters are casted to
-        lower precision, the memory footprint for storing the model may be slightly higher than the alternatives.
-        This method causes the highest number of casting operations, which may contribute to a slight increase
-        in the overall computation time.
-
-    Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted
-    to lower precision, as this may lead to significant quality loss.
+        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular
+        level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while
+        also ensuring important operations like normalization with learned parameters remain unaffected from the
+        downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory
+        footprint for storing the model may be slightly higher than the alternatives. This method causes the
+        highest number of casting operations, which may contribute to a slight increase in the overall computation
+        time.
+
+    Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to
+    lower precision, as this may lead to significant quality loss.
     """
-
+
     DIFFUSERS_MODEL = "diffusers_model"
     DIFFUSERS_LAYER = "diffusers_layer"
     PYTORCH_LAYER = "pytorch_layer"
 
+
 # fmt: off
 _SUPPORTED_DIFFUSERS_LAYERS = [
     AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool,
@@ -306,18 +325,20 @@ class LayerwiseUpcastingGranualarity(str, Enum):
 # fmt: on
 
 
-def apply_layerwise_upcasting_hook(module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> torch.nn.Module:
+def apply_layerwise_upcasting_hook(
+    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
+) -> torch.nn.Module:
     r"""
     Applies a `LayerwiseUpcastingHook` to a given module.
-
+
     Args:
         module (`torch.nn.Module`):
             The module to attach the hook to.
         storage_dtype (`torch.dtype`):
             The dtype to cast the module to before the forward pass.
         compute_dtype (`torch.dtype`):
             The dtype to cast the module to during the forward pass.
-
+
     Returns:
         `torch.nn.Module`:
             The same module, with the hook attached (the module is modified in place, so the result can be discarded).
@@ -337,9 +358,13 @@ def apply_layerwise_upcasting(
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL:
         return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype)
     if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER:
-        return _apply_layerwise_upcasting_diffusers_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        return _apply_layerwise_upcasting_diffusers_layer(
+            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
+        )
     if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER:
-        return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes)
+        return _apply_layerwise_upcasting_pytorch_layer(
+            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
+        )
 
 
 def _apply_layerwise_upcasting_diffusers_model(
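
For context, a hedged end-to-end sketch of how this dispatcher might be called (the model checkpoint is a placeholder, the full default signature of apply_layerwise_upcasting is not visible in this diff, parameter names follow the ones used above, and treating skip_modules_pattern as regex patterns over module names is an assumption):

import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.models.hooks import LayerwiseUpcastingGranualarity, apply_layerwise_upcasting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,    # weights kept in fp8 while idle
    compute_dtype=torch.bfloat16,         # forward passes still run in bf16
    granularity=LayerwiseUpcastingGranualarity.PYTORCH_LAYER,
    skip_modules_pattern=["norm"],        # assumed: patterns matched against module names
    skip_modules_classes=[],
)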
@@ -352,7 +377,7 @@ def _apply_layerwise_upcasting_diffusers_model(
     if not isinstance(module, ModelMixin):
         raise ValueError("The input module must be an instance of ModelMixin")
 
-    logger.debug(f"Applying layerwise upcasting to model \"{module.__class__.__name__}\"")
+    logger.debug(f'Applying layerwise upcasting to model "{module.__class__.__name__}"')
     apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype)
     return module
 
@@ -370,9 +395,9 @@ def _apply_layerwise_upcasting_diffusers_layer(
             or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
             or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS))
         ):
-            logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"")
+            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
             continue
-        logger.debug(f"Applying layerwise upcasting to layer \"{name}\"")
+        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
         apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
     return module
 
@@ -390,8 +415,8 @@ def _apply_layerwise_upcasting_pytorch_layer(
             or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
             or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
         ):
-            logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"")
+            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
             continue
-        logger.debug(f"Applying layerwise upcasting to layer \"{name}\"")
+        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
         apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
     return module
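
The two helpers above share the same filtering idea: walk named_modules(), skip anything whose name matches a skip pattern, whose class is explicitly excluded, or that is not in the supported-layer list, and hook everything else. A small self-contained sketch of that filter (matching names with re.search is an assumption; only the isinstance checks appear verbatim in these hunks):

import re
import torch

def should_skip(name, submodule, skip_modules_pattern, skip_modules_classes, supported_layers):
    return (
        any(re.search(pattern, name) for pattern in skip_modules_pattern)
        or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
        or not isinstance(submodule, tuple(supported_layers))
    )

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
for name, submodule in model.named_modules():
    if should_skip(name, submodule, ["norm"], [torch.nn.LayerNorm], [torch.nn.Linear]):
        continue
    print(f"would attach a layerwise upcasting hook to '{name}'")  # only the Linear ("0") qualifies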
