
Commit 7858f2c

Commit message: update
1 parent 558c64e · commit 7858f2c

9 files changed: +308 additions, −510 deletions


src/diffusers/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -107,7 +107,6 @@
             "I2VGenXLUNet",
             "Kandinsky3UNet",
             "LatteTransformer3DModel",
-            "LayerwiseUpcastingGranularity",
             "LTXVideoTransformer3DModel",
             "LuminaNextDiT2DModel",
             "MochiTransformer3DModel",
@@ -136,8 +135,6 @@
             "UNetSpatioTemporalConditionModel",
             "UVit2DModel",
             "VQModel",
-            "apply_layerwise_upcasting",
-            "apply_layerwise_upcasting_hook",
         ]
     )
     _import_structure["optimization"] = [
@@ -620,7 +617,6 @@
         I2VGenXLUNet,
         Kandinsky3UNet,
         LatteTransformer3DModel,
-        LayerwiseUpcastingGranularity,
         LTXVideoTransformer3DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
@@ -648,8 +644,6 @@
         UNetSpatioTemporalConditionModel,
         UVit2DModel,
         VQModel,
-        apply_layerwise_upcasting,
-        apply_layerwise_upcasting_hook,
     )
     from .optimization import (
         get_constant_schedule,
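
These deletions drop the layerwise-upcasting helpers from the top-level `diffusers` namespace (the `LayerwiseUpcastingGranularity` enum is removed outright). Going by the new `src/diffusers/hooks/__init__.py` added below, the two functions are re-exported from the new `hooks` package, so a caller would migrate roughly like this (a sketch, assuming the package layout shown in this commit):

```python
# Old top-level imports, removed by this commit:
# from diffusers import apply_layerwise_upcasting, apply_layerwise_upcasting_hook

# New location, per the hooks package added below:
from diffusers.hooks import apply_layerwise_upcasting, apply_layerwise_upcasting_hook
```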

src/diffusers/hooks/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+    from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook

src/diffusers/hooks/hooks.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+
+from ..utils.logging import get_logger
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class ModelHook:
+    r"""
+    A hook that contains callbacks to be executed just before and after the forward method of a model.
+    """
+
+    _is_stateful = False
+
+    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when a model is initialized.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module attached to this hook.
+        """
+        return module
+
+    def deinitialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when a model is deinitialized.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module attached to this hook.
+        """
+        module.forward = module._old_forward
+        del module._old_forward
+        return module
+
+    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
+        r"""
+        Hook that is executed just before the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass will be executed just after this event.
+            args (`Tuple[Any]`):
+                The positional arguments passed to the module.
+            kwargs (`Dict[str, Any]`):
+                The keyword arguments passed to the module.
+
+        Returns:
+            `Tuple[Tuple[Any], Dict[str, Any]]`:
+                A tuple with the treated `args` and `kwargs`.
+        """
+        return args, kwargs
+
+    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
+        r"""
+        Hook that is executed just after the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass has been executed just before this event.
+            output (`Any`):
+                The output of the module.
+
+        Returns:
+            `Any`: The processed `output`.
+        """
+        return output
+
+    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when the hook is detached from a module.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module detached from this hook.
+        """
+        return module
+
+    def reset_state(self, module: torch.nn.Module):
+        if self._is_stateful:
+            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+        return module
+
+
+class HookRegistry:
+    def __init__(self, module_ref: torch.nn.Module) -> None:
+        super().__init__()
+
+        self.hooks: Dict[str, ModelHook] = {}
+
+        self._module_ref = module_ref
+        self._hook_order = []
+
+    def register_hook(self, hook: ModelHook, name: str) -> None:
+        if name in self.hooks.keys():
+            logger.warning(f"Hook with name {name} already exists, replacing it.")
+
+        if hasattr(self._module_ref, "_old_forward"):
+            old_forward = self._module_ref._old_forward
+        else:
+            old_forward = self._module_ref.forward
+            self._module_ref._old_forward = self._module_ref.forward
+
+        self._module_ref = hook.initialize_hook(self._module_ref)
+
+        if hasattr(hook, "new_forward"):
+            new_forward = hook.new_forward
+        else:
+
+            def new_forward(module, *args, **kwargs):
+                args, kwargs = hook.pre_forward(module, *args, **kwargs)
+                output = old_forward(*args, **kwargs)
+                return hook.post_forward(module, output)
+
+        new_forward = functools.update_wrapper(new_forward, old_forward)
+        self._module_ref.forward = new_forward.__get__(self._module_ref)
+
+        self.hooks[name] = hook
+        self._hook_order.append(name)
+
+    def get_hook(self, name: str) -> ModelHook:
+        if name not in self.hooks.keys():
+            raise ValueError(f"Hook with name {name} not found.")
+        return self.hooks[name]
+
+    def remove_hook(self, name: str) -> None:
+        if name not in self.hooks.keys():
+            raise ValueError(f"Hook with name {name} not found.")
+        self.hooks[name].deinitialize_hook(self._module_ref)
+        del self.hooks[name]
+        self._hook_order.remove(name)
+
+    @classmethod
+    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
+        if not hasattr(module, "_diffusers_hook"):
+            module._diffusers_hook = cls(module)
+        return module._diffusers_hook
+
+    def __repr__(self) -> str:
+        hook_repr = ""
+        for i, hook_name in enumerate(self._hook_order):
+            hook_repr += f"  ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
+            if i < len(self._hook_order) - 1:
+                hook_repr += "\n"
+        return f"HookRegistry(\n{hook_repr}\n)"
src/diffusers/hooks/layerwise_upcasting.py

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import List, Type
+
+import torch
+
+from ..utils import get_logger
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# fmt: off
+_SUPPORTED_PYTORCH_LAYERS = [
+    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
+    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
+    torch.nn.Linear,
+]
+
+_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"]
+# fmt: on
+
+
+class LayerwiseUpcastingHook(ModelHook):
+    r"""
+    A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision
+    dtype for storage. This process may lead to quality loss in the output, but can significantly reduce the
+    memory footprint.
+    """
+
+    _is_stateful = False
+
+    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
+        self.storage_dtype = storage_dtype
+        self.compute_dtype = compute_dtype
+
+    def initialize_hook(self, module: torch.nn.Module):
+        module.to(dtype=self.storage_dtype)
+        return module
+
+    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
+        module.to(dtype=self.compute_dtype)
+        return args, kwargs
+
+    def post_forward(self, module: torch.nn.Module, output):
+        module.to(dtype=self.storage_dtype)
+        return output
+
+
+def apply_layerwise_upcasting(
+    module: torch.nn.Module,
+    storage_dtype: torch.dtype,
+    compute_dtype: torch.dtype,
+    skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN,
+    skip_modules_classes: List[Type[torch.nn.Module]] = [],
+) -> torch.nn.Module:
+    r"""
+    Applies layerwise upcasting to a given module. The module expected here is a Diffusers `ModelMixin`, but it
+    can be any `nn.Module` that uses Diffusers layers or PyTorch primitives.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
+            precision dtype for storage.
+        storage_dtype (`torch.dtype`):
+            The dtype to cast the module to before/after the forward pass for storage.
+        compute_dtype (`torch.dtype`):
+            The dtype to cast the module to during the forward pass for computation.
+        skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
+            A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
+        skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`):
+            A list of module classes to skip during the layerwise upcasting process.
+    """
+    for name, submodule in module.named_modules():
+        if (
+            any(re.search(pattern, name) for pattern in skip_modules_pattern)
+            or any(isinstance(submodule, module_class) for module_class in skip_modules_classes)
+            or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS))
+            or len(list(submodule.children())) > 0
+        ):
+            logger.debug(f'Skipping layerwise upcasting for layer "{name}"')
+            continue
+        logger.debug(f'Applying layerwise upcasting to layer "{name}"')
+        apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype)
+    return module
+
+
+def apply_layerwise_upcasting_hook(
+    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
+) -> torch.nn.Module:
+    r"""
+    Applies a `LayerwiseUpcastingHook` to a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to attach the hook to.
+        storage_dtype (`torch.dtype`):
+            The dtype to cast the module to before the forward pass.
+        compute_dtype (`torch.dtype`):
+            The dtype to cast the module to during the forward pass.
+
+    Returns:
+        `torch.nn.Module`:
+            The same module, with the hook attached (the module is modified in place).
+    """
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype)
+    registry.register_hook(hook, "layerwise_upcasting")
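
As a usage sketch (not part of this diff): after `apply_layerwise_upcasting`, every matched `Linear`/`Conv` layer keeps its weights in the low-precision storage dtype and is upcast only for the duration of its own forward pass, so only the layer currently executing sits in the compute dtype. The `torch.float8_e4m3fn` storage dtype below is an assumption that requires a PyTorch build providing it (2.1+); it stores weights at 1 byte per parameter versus 2 for `bfloat16`, which is where the memory savings come from:

```python
import torch

from diffusers.hooks import apply_layerwise_upcasting

# Toy stand-in for a transformer block; any module built from the supported
# Linear/Conv layers works, subject to the skip patterns.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 64),
)

apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,  # weights at rest: 1 byte/param (assumed available)
    compute_dtype=torch.bfloat16,       # weights during forward: 2 bytes/param
)

print(model[0].weight.dtype)  # torch.float8_e4m3fn (storage)
with torch.no_grad():
    out = model(torch.randn(1, 64, dtype=torch.bfloat16))
print(out.dtype)              # torch.bfloat16 (compute)
print(model[0].weight.dtype)  # recast to torch.float8_e4m3fn by post_forward
```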

src/diffusers/models/__init__.py

Lines changed: 0 additions & 10 deletions
@@ -51,11 +51,6 @@
     _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
-    _import_structure["layerwise_upcasting_utils"] = [
-        "LayerwiseUpcastingGranularity",
-        "apply_layerwise_upcasting",
-        "apply_layerwise_upcasting_hook",
-    ]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
     _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
@@ -128,11 +123,6 @@
             UNetControlNetXSModel,
         )
         from .embeddings import ImageProjection
-        from .layerwise_upcasting_utils import (
-            LayerwiseUpcastingGranularity,
-            apply_layerwise_upcasting,
-            apply_layerwise_upcasting_hook,
-        )
         from .modeling_utils import ModelMixin
        from .transformers import (
            AllegroTransformer3DModel,
