
Commit 13d5af7: init

1 parent 4b9f1c7

2 files changed: +266, -0 lines

src/diffusers/models/hooks.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Tuple

import torch


# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the processed `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.

        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        # Stateful hooks must override this method to clear their internal state.
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")


class SequentialHook(ModelHook):
    r"""A hook that can contain several hooks and iterates through them at each event."""

    def __init__(self, *hooks):
        self.hooks = hooks

    def init_hook(self, module):
        for hook in self.hooks:
            module = hook.init_hook(module)
        return module

    def pre_forward(self, module, *args, **kwargs):
        for hook in self.hooks:
            args, kwargs = hook.pre_forward(module, *args, **kwargs)
        return args, kwargs

    def post_forward(self, module, output):
        for hook in self.hooks:
            output = hook.post_forward(module, output)
        return output

    def detach_hook(self, module):
        for hook in self.hooks:
            module = hook.detach_hook(module)
        return module

    def reset_state(self, module):
        # Only forward the reset to the hooks that actually carry state.
        for hook in self.hooks:
            if hook._is_stateful:
                hook.reset_state(module)


def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module:
    r"""
    Adds a hook to a given module. This rewrites the `forward` method of the module to include the hook. To remove
    this behavior and restore the original `forward` method, use `remove_hook_from_module`.

    <Tip warning={true}>

    If the module already contains a hook, this replaces it with the new hook by default. To chain two hooks
    together, pass `append=True`, which combines the current and the new hook into an instance of the
    `SequentialHook` class.

    </Tip>

    Args:
        module (`torch.nn.Module`):
            The module to attach a hook to.
        hook (`ModelHook`):
            The hook to attach.
        append (`bool`, *optional*, defaults to `False`):
            Whether the hook should be chained with an existing one (if the module already contains a hook) or not.

    Returns:
        `torch.nn.Module`:
            The same module, with the hook attached (the module is modified in place, so the result can be discarded).
    """
    original_hook = hook

    if append and getattr(module, "_diffusers_hook", None) is not None:
        old_hook = module._diffusers_hook
        remove_hook_from_module(module)
        hook = SequentialHook(old_hook, hook)

    if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
        # If we already put some hook on this module, we replace it with the new one.
        old_forward = module._old_forward
    else:
        old_forward = module.forward
        module._old_forward = old_forward

    module = hook.init_hook(module)
    module._diffusers_hook = hook

    if hasattr(original_hook, "new_forward"):
        new_forward = original_hook.new_forward
    else:

        def new_forward(module, *args, **kwargs):
            args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
            output = module._old_forward(*args, **kwargs)
            return module._diffusers_hook.post_forward(module, output)

    # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
    # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
    if "GraphModuleImpl" in str(type(module)):
        module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
    else:
        module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)

    return module


def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
    """
    Removes any hook attached to a module via `add_hook_to_module`.

    Args:
        module (`torch.nn.Module`):
            The module to detach the hook from.
        recurse (`bool`, defaults to `False`):
            Whether to remove the hooks recursively.

    Returns:
        `torch.nn.Module`:
            The same module, with the hook detached (the module is modified in place, so the result can be discarded).
    """
    if hasattr(module, "_diffusers_hook"):
        module._diffusers_hook.detach_hook(module)
        delattr(module, "_diffusers_hook")

    if hasattr(module, "_old_forward"):
        # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
        # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
        if "GraphModuleImpl" in str(type(module)):
            module.__class__.forward = module._old_forward
        else:
            module.forward = module._old_forward
        delattr(module, "_old_forward")

    if recurse:
        for child in module.children():
            remove_hook_from_module(child, recurse)

    return module


def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
    """
    Resets the state of all stateful hooks attached to a module.

    Args:
        module (`torch.nn.Module`):
            The module to reset the stateful hooks from.
    """
    if hasattr(module, "_diffusers_hook") and (
        module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)
    ):
        module._diffusers_hook.reset_state(module)

    if recurse:
        for child in module.children():
            reset_stateful_hooks(child, recurse)
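
For orientation (not part of the commit), a minimal usage sketch of this API: a hypothetical stateful hook is attached with `add_hook_to_module`, chained with a second hook via `append=True`, reset with `reset_stateful_hooks`, and removed with `remove_hook_from_module`. The `CountingHook` class and the `torch.nn.Linear` module below are illustrative assumptions, not diffusers API.

    import torch
    from diffusers.models.hooks import (
        ModelHook,
        add_hook_to_module,
        remove_hook_from_module,
        reset_stateful_hooks,
    )

    class CountingHook(ModelHook):
        # Hypothetical hook that counts forward calls; it is stateful,
        # so it implements `reset_state`.
        _is_stateful = True

        def init_hook(self, module):
            self.num_forwards = 0
            return module

        def pre_forward(self, module, *args, **kwargs):
            self.num_forwards += 1
            return args, kwargs

        def reset_state(self, module):
            self.num_forwards = 0

    module = torch.nn.Linear(4, 4)
    add_hook_to_module(module, CountingHook())                # wraps module.forward
    add_hook_to_module(module, CountingHook(), append=True)   # chains into a SequentialHook
    module(torch.randn(1, 4))                                 # both hooks see this call
    reset_stateful_hooks(module)                              # zeroes both counters
    remove_hook_from_module(module)                           # restores the original forward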
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn

from ..models.hooks import ModelHook, add_hook_to_module
from ..utils import logging
from .pipeline_utils import DiffusionPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class TeaCacheConfig:
    pass


class TeaCacheDenoiserState:
    def __init__(self):
        self.iteration = 0
        self.accumulated_l1_difference = 0.0
        self.timestep_modulated_cache = None

    def reset(self):
        self.iteration = 0
        self.accumulated_l1_difference = 0.0
        self.timestep_modulated_cache = None


def apply_teacache(pipeline: DiffusionPipeline, config: TeaCacheConfig, denoiser: Optional[nn.Module]) -> None:
    r"""Applies [TeaCache]() to a given pipeline or denoiser module.

    Args:
        TODO
    """
