
Commit d9e7372

init
1 parent 6131a93 commit d9e7372

File tree

3 files changed: +248, -0 lines changed


src/diffusers/models/hooks.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Callable, Dict, Tuple

import torch


# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
    with existing PyTorch hooks is that they get passed along the kwargs.
    """

    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.
        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.
        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.
        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.
        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.
        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.
        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        return module


class SequentialHook(ModelHook):
    r"""A hook that can contain several hooks and iterates through them at each event."""

    def __init__(self, *hooks):
        self.hooks = hooks

    def init_hook(self, module):
        for hook in self.hooks:
            module = hook.init_hook(module)
        return module

    def pre_forward(self, module, *args, **kwargs):
        for hook in self.hooks:
            args, kwargs = hook.pre_forward(module, *args, **kwargs)
        return args, kwargs

    def post_forward(self, module, output):
        for hook in self.hooks:
            output = hook.post_forward(module, output)
        return output

    def detach_hook(self, module):
        for hook in self.hooks:
            module = hook.detach_hook(module)
        return module

    def reset_state(self, module):
        for hook in self.hooks:
            module = hook.reset_state(module)
        return module


class FasterCacheHook(ModelHook):
    def __init__(
        self,
        skip_callback: Callable[[torch.nn.Module], bool],
    ) -> None:
        super().__init__()

        self.skip_callback = skip_callback

        self.cache = None
        self._iteration = 0

    def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
        args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)

        if self.cache is not None and self.skip_callback(module):
            output = self.cache
        else:
            output = module._old_forward(*args, **kwargs)

        return module._diffusers_hook.post_forward(module, output)

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        self.cache = output
        return output

    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
        self.cache = None
        self._iteration = 0
        return module


def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
    r"""
    Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook; to remove
    this behavior and restore the original `forward` method, use `remove_hook_from_module`.
    <Tip warning={true}>
    If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
    together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
    </Tip>
    Args:
        module (`torch.nn.Module`):
            The module to attach a hook to.
        hook (`ModelHook`):
            The hook to attach.
        append (`bool`, *optional*, defaults to `False`):
            Whether the hook should be chained with an existing one (if the module already contains a hook) or not.
    Returns:
        `torch.nn.Module`:
            The same module, with the hook attached (the module is modified in place, so the result can be discarded).
    """
    original_hook = hook

    if append and getattr(module, "_diffusers_hook", None) is not None:
        old_hook = module._diffusers_hook
        remove_hook_from_module(module)
        hook = SequentialHook(old_hook, hook)

    if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
        # If we already put some hook on this module, we replace it with the new one.
        old_forward = module._old_forward
    else:
        old_forward = module.forward
        module._old_forward = old_forward

    module = hook.init_hook(module)
    module._diffusers_hook = hook

    if hasattr(original_hook, "new_forward"):
        new_forward = original_hook.new_forward
    else:

        def new_forward(module, *args, **kwargs):
            args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
            output = module._old_forward(*args, **kwargs)
            return module._diffusers_hook.post_forward(module, output)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
    else:
        module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

    return module


def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
    """
    Removes any hook attached to a module via `add_hook_to_module`.
    Args:
        module (`torch.nn.Module`):
            The module to remove the hook from.
        recurse (`bool`, defaults to `False`):
            Whether to remove the hooks recursively.
    Returns:
        `torch.nn.Module`:
            The same module, with the hook detached (the module is modified in place, so the result can be discarded).
    """

    if hasattr(module, "_diffusers_hook"):
        module._diffusers_hook.detach_hook(module)
        delattr(module, "_diffusers_hook")

    if hasattr(module, "_old_forward"):
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        if "GraphModuleImpl" in str(type(module)):
            module.__class__.forward = module._old_forward
        else:
            module.forward = module._old_forward
        delattr(module, "_old_forward")

    if recurse:
        for child in module.children():
            remove_hook_from_module(child, recurse)

    return module
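
To see how these pieces fit together, here is a minimal usage sketch. It is not part of the commit: the `ShapeLoggingHook` class and the reuse-everything skip callback below are made up for illustration, and an actual FasterCache skip policy would presumably depend on the denoising step rather than being a constant.

```python
import torch

from diffusers.models.hooks import (
    FasterCacheHook,
    ModelHook,
    add_hook_to_module,
    remove_hook_from_module,
)


class ShapeLoggingHook(ModelHook):
    """Hypothetical hook that logs the output shape of the wrapped module."""

    def post_forward(self, module, output):
        print(f"{module.__class__.__name__} -> {tuple(output.shape)}")
        return output


def reuse_cache_if_available(module: torch.nn.Module) -> bool:
    # Hypothetical skip policy: reuse the cached output whenever one exists.
    # (FasterCacheHook only consults this callback once `self.cache` is set.)
    return True


layer = torch.nn.Linear(8, 8)

# Rewrites `layer.forward` so calls route through the hook's pre_forward/post_forward
# callbacks; the original forward is kept in `_old_forward`.
add_hook_to_module(layer, ShapeLoggingHook())

# `append=True` chains the existing hook and the new one into a SequentialHook.
# Because FasterCacheHook defines `new_forward`, it drives the forward call and can
# skip `_old_forward` entirely when the skip callback returns True.
add_hook_to_module(layer, FasterCacheHook(reuse_cache_if_available), append=True)

x = torch.randn(2, 8)
layer(x)  # first call: runs the real forward; post_forward caches the output
layer(x)  # second call: a cache exists and the callback returns True, so the cache is reused

# Restore the original forward method and drop the hook attributes.
remove_hook_from_module(layer)
```

One detail worth noting: `add_hook_to_module` wraps `original_hook.new_forward` when it exists, so even when a `FasterCacheHook` is chained via `SequentialHook`, its `new_forward` still controls the call while the `pre_forward`/`post_forward` callbacks of every chained hook run through the `SequentialHook`.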
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 0 deletions
@@ -1088,6 +1088,10 @@ def maybe_free_model_hooks(self):
         is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
         functions correctly when applying enable_model_cpu_offload.
         """
+
+        if hasattr(self, "_diffusers_hook"):
+            self._diffusers_hook.reset_state()
+
         if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
             # `enable_model_cpu_offload` has not be called, so silently do nothing
             return
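
The pipeline-level `_diffusers_hook` object used here is not introduced in this commit, and its no-argument `reset_state()` differs from the `ModelHook.reset_state(module)` signature above, so it presumably refers to a registry added in a follow-up change. A hypothetical sketch of the same idea using only the primitives from `hooks.py`, clearing per-module hook state between pipeline calls:

```python
import torch


def reset_all_hook_states(model: torch.nn.Module) -> None:
    # Walk every submodule and reset any attached diffusers hook, so state cached
    # during one generation does not leak into the next pipeline call.
    for submodule in model.modules():
        hook = getattr(submodule, "_diffusers_hook", None)
        if hook is not None:
            hook.reset_state(submodule)
```

Without a reset of this kind, a `FasterCacheHook` attached during one generation would keep serving its stale cached output on the next call to the pipeline.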
