Commit d1737e3

Commit message: update
1 parent 661bde0 commit d1737e3

File tree: 3 files changed, +361 -0 lines changed

src/diffusers/hooks/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from ..utils import is_torch_available


if is_torch_available():
    from .group_offloading import apply_group_offloading
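
With the conditional export above, callers pick up the new entry point directly from the hooks subpackage. A one-line sketch (assumes torch is installed; the import path mirrors what this commit adds):

from diffusers.hooks import apply_group_offloading  # only available when torch is installed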

src/diffusers/hooks/group_offloading.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Optional, Union

import torch

from .hooks import HookRegistry, ModelHook


_TRANSFORMER_STACK_IDENTIFIERS = [
    "transformer_blocks",
    "single_transformer_blocks",
    "temporal_transformer_blocks",
    "transformer_layers",
    "layers",
    "blocks",
]


class ModuleGroup:
    def __init__(
        self,
        modules: List[torch.nn.Module],
        offload_device: torch.device,
        onload_device: torch.device,
        offload_leader: torch.nn.Module,
        onload_leader: Optional[torch.nn.Module] = None,
    ) -> None:
        self.modules = modules
        self.offload_device = offload_device
        self.onload_device = onload_device
        self.offload_leader = offload_leader
        self.onload_leader = onload_leader


class GroupOffloadingHook(ModelHook):
    r"""
    A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads them to the accelerator device
    for computation. Each group has one "onload leader" module that is responsible for onloading, and an
    "offload leader" module that is responsible for offloading.

    This implementation assumes the following:
    - For offload_group_patterns="diffusers_block", the leader of a group can be automatically determined. For a
      custom user-provided regex pattern, the module that triggers its forward pass first is considered the leader.
    - The inputs are already on the correct device. This is expected because the hook does not modify the state of
      inputs or outputs at any stage of the forward pass. If an error is raised due to the device of modules and
      inputs not matching during the forward pass for any model in Diffusers, this means that the forward pass of the
      model is not written in the expected way. Please open an issue at
      https://github.com/huggingface/diffusers/issues if you encounter such an error.
    """

    def __init__(self, group: ModuleGroup, offload_on_init: bool = True) -> None:
        self.group = group
        self.offload_on_init = offload_on_init

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.offload_on_init:
            self.offload_(module)
        return module

    def onload_(self, module: torch.nn.Module) -> None:
        if self.group.onload_leader is not None and self.group.onload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.onload_device)

    def offload_(self, module: torch.nn.Module) -> None:
        if self.group.offload_leader == module:
            for group_module in self.group.modules:
                group_module.to(self.group.offload_device)

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        # For custom regex patterns, the first module to run its forward pass becomes the onload leader.
        if self.group.onload_leader is None:
            self.group.onload_leader = module
        self.onload_(module)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        self.offload_(module)
        return output


def apply_group_offloading(
    module: torch.nn.Module,
    offload_group_patterns: Union[str, List[str]] = "diffusers_block",
    num_blocks_per_group: Optional[int] = None,
    offload_device: torch.device = torch.device("cpu"),
    onload_device: torch.device = torch.device("cuda"),
    force_offload: bool = True,
) -> None:
    if offload_group_patterns == "diffusers_block":
        _apply_group_offloading_diffusers_block(
            module, num_blocks_per_group, offload_device, onload_device, force_offload
        )
    else:
        _apply_group_offloading_group_patterns(
            module, offload_group_patterns, offload_device, onload_device, force_offload
        )


def _apply_group_offloading_diffusers_block(
    module: torch.nn.Module,
    num_blocks_per_group: int,
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
) -> None:
    if num_blocks_per_group is None:
        raise ValueError('num_blocks_per_group must be provided when using offload_group_patterns="diffusers_block".')

    for transformer_stack_identifier in _TRANSFORMER_STACK_IDENTIFIERS:
        if not hasattr(module, transformer_stack_identifier) or not isinstance(
            getattr(module, transformer_stack_identifier), torch.nn.ModuleList
        ):
            continue

        transformer_stack = getattr(module, transformer_stack_identifier)
        num_blocks = len(transformer_stack)

        for i in range(0, num_blocks, num_blocks_per_group):
            blocks = transformer_stack[i : i + num_blocks_per_group]
            group = ModuleGroup(
                blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
            )
            # The first group is always offloaded at initialization; later groups are only offloaded
            # up front when force_offload is set.
            should_offload = force_offload or i == 0
            _apply_group_offloading(group, should_offload)


def _apply_group_offloading_group_patterns(
    module: torch.nn.Module,
    offload_group_patterns: List[str],
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
) -> None:
    per_group_modules = []
    for i, offload_group_pattern in enumerate(offload_group_patterns):
        group_modules = []
        group_module_names = []
        # Use a distinct loop variable so the submodule iterator does not shadow `module`.
        for name, submodule in module.named_modules():
            if re.search(offload_group_pattern, name) is not None:
                group_modules.append(submodule)
                group_module_names.append(name)
        per_group_modules.append(
            {
                "modules": group_modules,
                "module_names": group_module_names,
            }
        )

    # Check if there are any overlapping modules between groups
    for i, group in enumerate(per_group_modules):
        for j, other_group in enumerate(per_group_modules):
            if j <= i:
                continue
            if any(module_name in group["module_names"] for module_name in other_group["module_names"]):
                raise ValueError(
                    f"Overlapping modules between groups {i} and {j}. Please ensure that offloading group patterns "
                    f"are mutually exclusive."
                )

    # Apply offloading to each group
    for group in per_group_modules:
        # TODO: handle offload leader correctly
        group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1])
        _apply_group_offloading(group, force_offload)


def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool) -> None:
    for module in group.modules:
        hook = GroupOffloadingHook(group, offload_on_init=offload_on_init)
        registry = HookRegistry.check_if_exists_or_initialize(module)
        registry.register_hook(hook, "group_offloading")
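
As a usage illustration of the entry point above, the sketch below applies block-wise offloading to a diffusers transformer. The model class, checkpoint id, and group size are assumptions chosen for the example, not part of this commit:

# Hedged example: FluxTransformer2DModel and the checkpoint id are illustrative choices.
import torch

from diffusers import FluxTransformer2DModel
from diffusers.hooks import apply_group_offloading

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Group the transformer blocks four at a time. Each group is stored on the CPU and moved
# to the accelerator only around the forward pass of its onload leader (the first block).
apply_group_offloading(
    transformer,
    offload_group_patterns="diffusers_block",
    num_blocks_per_group=4,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,  # offload every group immediately at hook initialization
)

Because the hook does not move inputs, the caller remains responsible for placing model inputs on onload_device before running the forward pass, as the docstring above notes.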

src/diffusers/hooks/hooks.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__)  # pylint: disable=invalid-name


class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is deinitialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        module.forward = module._old_forward
        del module._old_forward
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.

        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module


class HookRegistry:
    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, ModelHook] = {}

        self._module_ref = module_ref
        self._hook_order = []

    def register_hook(self, hook: ModelHook, name: str) -> None:
        if name in self.hooks.keys():
            logger.warning(f"Hook with name {name} already exists, replacing it.")

        # Preserve the original forward exactly once, even when multiple hooks are registered.
        if hasattr(self._module_ref, "_old_forward"):
            old_forward = self._module_ref._old_forward
        else:
            old_forward = self._module_ref.forward
            self._module_ref._old_forward = self._module_ref.forward

        self._module_ref = hook.initialize_hook(self._module_ref)

        if hasattr(hook, "new_forward"):
            new_forward = hook.new_forward
        else:

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = old_forward(*args, **kwargs)
                return hook.post_forward(module, output)

        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        if "GraphModuleImpl" in str(type(self._module_ref)):
            self._module_ref.__class__.forward = functools.update_wrapper(
                functools.partial(new_forward, self._module_ref), old_forward
            )
        else:
            self._module_ref.forward = functools.update_wrapper(
                functools.partial(new_forward, self._module_ref), old_forward
            )

        self.hooks[name] = hook
        self._hook_order.append(name)

    def get_hook(self, name: str) -> ModelHook:
        if name not in self.hooks.keys():
            raise ValueError(f"Hook with name {name} not found.")
        return self.hooks[name]

    def remove_hook(self, name: str) -> None:
        if name not in self.hooks.keys():
            raise ValueError(f"Hook with name {name} not found.")
        self.hooks[name].deinitalize_hook(self._module_ref)
        del self.hooks[name]
        self._hook_order.remove(name)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        hook_repr = ""
        for i, hook_name in enumerate(self._hook_order):
            hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
            if i < len(self._hook_order) - 1:
                hook_repr += "\n"
        return f"HookRegistry(\n{hook_repr}\n)"
