
Commit 8975bbf

Commit message: update
1 parent 7c31bb0 commit 8975bbf

File tree

5 files changed: +268 -180 lines changed

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -107,6 +107,7 @@
 "I2VGenXLUNet",
 "Kandinsky3UNet",
 "LatteTransformer3DModel",
+"LayerwiseUpcastingGranularity",
 "LTXVideoTransformer3DModel",
 "LuminaNextDiT2DModel",
 "MochiTransformer3DModel",
@@ -135,6 +136,8 @@
 "UNetSpatioTemporalConditionModel",
 "UVit2DModel",
 "VQModel",
+"apply_layerwise_upcasting",
+"apply_layerwise_upcasting_hook",
 ]
 )
 _import_structure["optimization"] = [
@@ -617,6 +620,7 @@
 I2VGenXLUNet,
 Kandinsky3UNet,
 LatteTransformer3DModel,
+LayerwiseUpcastingGranularity,
 LTXVideoTransformer3DModel,
 LuminaNextDiT2DModel,
 MochiTransformer3DModel,
@@ -644,6 +648,8 @@
 UNetSpatioTemporalConditionModel,
 UVit2DModel,
 VQModel,
+apply_layerwise_upcasting,
+apply_layerwise_upcasting_hook,
 )
 from .optimization import (
 get_constant_schedule,
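
Note: for context, a minimal usage sketch of the API these new top-level exports point at. It assumes `apply_layerwise_upcasting` keeps the signature it had in `hooks.py` before this commit (visible in the removals further down); the tiny `torch.nn.Sequential` model is only an illustrative stand-in, not part of the commit.

    import torch
    from diffusers import LayerwiseUpcastingGranularity, apply_layerwise_upcasting

    # Stand-in model; any torch.nn.Module (e.g. a diffusers transformer) is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.SiLU(), torch.nn.Linear(16, 8))

    # Store weights in float8, upcast each supported layer to bfloat16 only around its forward pass.
    apply_layerwise_upcasting(
        model,
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity=LayerwiseUpcastingGranularity.PYTORCH_LAYER,
    )

    with torch.no_grad():
        output = model(torch.randn(1, 8, dtype=torch.bfloat16))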

src/diffusers/models/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -123,6 +123,11 @@
 UNetControlNetXSModel,
 )
 from .embeddings import ImageProjection
+from .layerwise_upcasting_utils import (
+    LayerwiseUpcastingGranularity,
+    apply_layerwise_upcasting,
+    apply_layerwise_upcasting_hook,
+)
 from .modeling_utils import ModelMixin
 from .transformers import (
 AllegroTransformer3DModel,
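
Note: the three re-exported names now resolve to the new `layerwise_upcasting_utils` module (one of the changed files, not shown in this excerpt), so they can also be imported from `diffusers.models`. A hedged sketch for attaching the hook to a single submodule, assuming the storage/compute semantics described in the `LayerwiseUpcastingHook` docstring removed from `hooks.py` below:

    import torch
    from diffusers.models import apply_layerwise_upcasting_hook

    proj = torch.nn.Linear(32, 32)
    # Keep this one layer stored in float8 between calls, but run it in float16.
    apply_layerwise_upcasting_hook(proj, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float16)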

src/diffusers/models/hooks.py

Lines changed: 22 additions & 180 deletions
@@ -13,25 +13,11 @@
 # limitations under the License.
 
 import functools
-import re
-from enum import Enum
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, Tuple
 
 import torch
 
 from ..utils import get_logger
-from .attention import FeedForward, LuminaFeedForward
-from .embeddings import (
-    AttentionPooling,
-    CogVideoXPatchEmbed,
-    CogView3PlusPatchEmbed,
-    GLIGENTextBoundingboxProjection,
-    HunyuanDiTAttentionPool,
-    LuminaPatchEmbed,
-    MochiAttentionPool,
-    PixArtAlphaTextProjection,
-    TimestepEmbedding,
-)
 
 
 logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -44,6 +30,8 @@ class ModelHook:
     with PyTorch existing hooks is that they get passed along the kwargs.
     """
 
+    _is_stateful = False
+
     def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
@@ -95,6 +83,11 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         """
         return module
 
+    def reset_state(self, module: torch.nn.Module):
+        if self._is_stateful:
+            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+        return module
+
 
 class SequentialHook(ModelHook):
     r"""A hook that can contain several hooks and iterates through them at each event."""
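
Note: the new `_is_stateful` flag and default `reset_state` define an opt-in protocol for hooks that accumulate state across forward passes. A hypothetical illustration (the `CounterHook` name and behaviour are not part of this commit; the import path assumes this file's location in the package):

    from diffusers.models.hooks import ModelHook

    class CounterHook(ModelHook):
        r"""Illustrative stateful hook that counts forward passes and can be reset."""

        _is_stateful = True

        def __init__(self):
            self.num_forwards = 0

        def pre_forward(self, module, *args, **kwargs):
            self.num_forwards += 1
            return args, kwargs

        def reset_state(self, module):
            self.num_forwards = 0
            return module
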
@@ -122,34 +115,12 @@ def detach_hook(self, module):
             module = hook.detach_hook(module)
         return module
 
-
-class LayerwiseUpcastingHook(ModelHook):
-    r"""
-    A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast
-    the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a
-    lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss
-    in the output, but can significantly reduce the memory footprint.
-    """
-
-    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
-        self.storage_dtype = storage_dtype
-        self.compute_dtype = compute_dtype
-
-    def init_hook(self, module: torch.nn.Module):
-        module.to(dtype=self.storage_dtype)
+    def reset_state(self, module):
+        for hook in self.hooks:
+            if hook._is_stateful:
+                hook.reset_state(module)
         return module
 
-    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
-        module.to(dtype=self.compute_dtype)
-        # How do we account for LongTensor, BoolTensor, etc.?
-        # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args)
-        # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()}
-        return args, kwargs
-
-    def post_forward(self, module: torch.nn.Module, output):
-        module.to(dtype=self.storage_dtype)
-        return output
-
 
 def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
     r"""
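
Note: when several hooks are attached to one module via `add_hook_to_module(..., append=True)`, they are chained together (presumably via `SequentialHook`, in the accelerate-style design this file follows), and the `SequentialHook.reset_state` added above forwards the reset only to hooks that declare `_is_stateful = True`. A hedged sketch reusing the hypothetical `CounterHook` from the previous note:

    import torch
    from diffusers.models.hooks import add_hook_to_module

    layer = torch.nn.Linear(4, 4)
    add_hook_to_module(layer, CounterHook())
    add_hook_to_module(layer, CounterHook(), append=True)  # second hook is chained, not replaced

    layer(torch.randn(2, 4))
    layer._diffusers_hook.reset_state(layer)  # each stateful hook in the chain clears its counter
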
@@ -244,148 +215,19 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
     return module
 
 
-def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any:
-    r"""
-    Aligns the dtype of a tensor or a list of tensors to a given dtype.
-
-    Args:
-        input (`Any`):
-            The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these
-            types, it will be returned as is.
-        dtype (`torch.dtype`):
-            The dtype to align the tensor(s) to.
-    Returns:
-        `Any`:
-            The tensor or list of tensors aligned to the given dtype.
-    """
-    if isinstance(input, torch.Tensor):
-        return input.to(dtype=dtype)
-    if isinstance(input, (list, tuple)):
-        return [align_maybe_tensor_dtype(t, dtype) for t in input]
-    if isinstance(input, dict):
-        return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()}
-    return input
-
-
-class LayerwiseUpcastingGranularity(str, Enum):
-    r"""
-    An enumeration class that defines the granularity of the layerwise upcasting process.
-
-    Granularity can be one of the following:
-    - `DIFFUSERS_LAYER`:
-        Applies layerwise upcasting to the lower-level diffusers layers of the model. This method is applied to
-        only those layers that are a group of linear layers, while excluding precision-critical layers like
-        modulation and normalization layers.
-    - `PYTORCH_LAYER`:
-        Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular
-        level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while
-        also ensuring important operations like normalization with learned parameters remain unaffected from the
-        downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory
-        footprint for storing the model may be slightly higher than the alternatives. This method causes the
-        highest number of casting operations, which may contribute to a slight increase in the overall computation
-        time.
-
-    Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to
-    lower precision, as this may lead to significant quality loss.
+def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
     """
-
-    DIFFUSERS_LAYER = "diffusers_layer"
-    PYTORCH_LAYER = "pytorch_layer"
-
-
-# fmt: off
-_SUPPORTED_DIFFUSERS_LAYERS = [
-    AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool,
-    CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed,
-    TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection,
-    FeedForward, LuminaFeedForward,
-]
-
-_SUPPORTED_PYTORCH_LAYERS = [
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-]
-
-_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
-# fmt: on
-
-
-def apply_layerwise_upcasting_hook(
-    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
-) -> torch.nn.Module:
-    r"""
-    Applies a `LayerwiseUpcastingHook` to a given module.
+    Resets the state of all stateful hooks attached to a module.
 
     Args:
         module (`torch.nn.Module`):
-            The module to attach the hook to.
-        storage_dtype (`torch.dtype`):
-            The dtype to cast the module to before the forward pass.
-        compute_dtype (`torch.dtype`):
-            The dtype to cast the module to during the forward pass.
-
-    Returns:
-        `torch.nn.Module`:
-            The same module, with the hook attached (the module is modified in place, so the result can be discarded).
+            The module to reset the stateful hooks from.
     """
-    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype)
-    return add_hook_to_module(module, hook, append=True)
-
-
-def apply_layerwise_upcasting(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER,
-    skip_modules_pattern: List[str] = [],
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
-) -> torch.nn.Module:
-    if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER:
-        return _apply_layerwise_upcasting_diffusers_layer(
-            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
-        )
-    if granularity == LayerwiseUpcastingGranularity.PYTORCH_LAYER:
-        return _apply_layerwise_upcasting_pytorch_layer(
-            module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes
-        )
-
-
-def _apply_layerwise_upcasting_diffusers_layer(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
-) -> torch.nn.Module:
-    for name, submodule in module.named_modules():
-        if (
-            any(re.search(pattern, name) for pattern in skip_modules_pattern)
-            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS))
-        ):
-            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
-            continue
-        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
-    return module
+    if hasattr(module, "_diffusers_hook") and (
+        module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
+    ):
+        module._diffusers_hook.reset_state(module)
 
-
-def _apply_layerwise_upcasting_pytorch_layer(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN,
-    skip_modules_classes: List[Type[torch.nn.Module]] = [],
-) -> torch.nn.Module:
-    for name, submodule in module.named_modules():
-        if (
-            any(re.search(pattern, name) for pattern in skip_modules_pattern)
-            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
-            or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
-        ):
-            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
-            continue
-        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
-        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
-    return module
+    if recurse:
+        for child in module.children():
+            reset_stateful_hooks(child, recurse)
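
Note: `reset_stateful_hooks` is the new entry point for clearing hook state between runs; with `recurse=True` it also visits child modules. A hedged usage sketch, again using the hypothetical `CounterHook` from the earlier note:

    import torch
    from diffusers.models.hooks import add_hook_to_module, reset_stateful_hooks

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
    for submodule in model:
        add_hook_to_module(submodule, CounterHook())

    model(torch.randn(2, 4))
    # Clears the counter on every hooked child before the next generation.
    reset_stateful_hooks(model, recurse=True)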

0 commit comments
