
Commit 2436b3f

refactor
1 parent 35296eb commit 2436b3f

6 files changed: +237 −322 lines changed

src/diffusers/__init__.py

Lines changed: 7 additions & 6 deletions
@@ -75,6 +75,12 @@
     _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
 
 else:
+    _import_structure["hooks"].extend(
+        [
+            "PyramidAttentionBroadcastConfig",
+            "apply_pyramid_attention_broadcast",
+        ]
+    )
     _import_structure["models"].extend(
         [
             "AllegroTransformer3DModel",
@@ -336,7 +342,6 @@
             "PixArtAlphaPipeline",
             "PixArtSigmaPAGPipeline",
             "PixArtSigmaPipeline",
-            "PyramidAttentionBroadcastConfig",
             "ReduxImageEncoder",
             "SanaPAGPipeline",
             "SanaPipeline",
@@ -423,8 +428,6 @@
             "WuerstchenCombinedPipeline",
             "WuerstchenDecoderPipeline",
             "WuerstchenPriorPipeline",
-            "apply_pyramid_attention_broadcast",
-            "apply_pyramid_attention_broadcast_on_module",
         ]
     )
 
@@ -589,6 +592,7 @@
     except OptionalDependencyNotAvailable:
         from .utils.dummy_pt_objects import *  # noqa F403
     else:
+        from .hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
        from .models import (
            AllegroTransformer3DModel,
            AsymmetricAutoencoderKL,
@@ -828,7 +832,6 @@
            PixArtAlphaPipeline,
            PixArtSigmaPAGPipeline,
            PixArtSigmaPipeline,
-           PyramidAttentionBroadcastConfig,
            ReduxImageEncoder,
            SanaPAGPipeline,
            SanaPipeline,
@@ -913,8 +916,6 @@
            WuerstchenCombinedPipeline,
            WuerstchenDecoderPipeline,
            WuerstchenPriorPipeline,
-           apply_pyramid_attention_broadcast,
-           apply_pyramid_attention_broadcast_on_module,
        )
 
    try:
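
After this change, `PyramidAttentionBroadcastConfig` and `apply_pyramid_attention_broadcast` are re-exported through the new `hooks` subpackage rather than `pipelines`, while staying importable from the package root; `apply_pyramid_attention_broadcast_on_module` is no longer re-exported at the root. A minimal sketch of the resulting import path (assuming a torch-enabled install, since these names only land in `_import_structure` in the torch branch):

    # Both names now resolve from diffusers.hooks via the package root.
    from diffusers import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast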

src/diffusers/hooks/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

src/diffusers/hooks/hooks.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+from typing import Any, Dict, Tuple
+
+import torch
+
+from ..utils.logging import get_logger
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class ModelHook:
+    r"""
+    A hook that contains callbacks to be executed just before and after the forward method of a model.
+    """
+
+    _is_stateful = False
+
+    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when a model is initialized.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module attached to this hook.
+        """
+        return module
+
+    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when a model is deinitalized.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module attached to this hook.
+        """
+        module.forward = module._old_forward
+        del module._old_forward
+        return module
+
+    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
+        r"""
+        Hook that is executed just before the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass will be executed just after this event.
+            args (`Tuple[Any]`):
+                The positional arguments passed to the module.
+            kwargs (`Dict[Str, Any]`):
+                The keyword arguments passed to the module.
+        Returns:
+            `Tuple[Tuple[Any], Dict[Str, Any]]`:
+                A tuple with the treated `args` and `kwargs`.
+        """
+        return args, kwargs
+
+    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
+        r"""
+        Hook that is executed just after the forward method of the model.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module whose forward pass been executed just before this event.
+            output (`Any`):
+                The output of the module.
+        Returns:
+            `Any`: The processed `output`.
+        """
+        return output
+
+    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
+        r"""
+        Hook that is executed when the hook is detached from a module.
+
+        Args:
+            module (`torch.nn.Module`):
+                The module detached from this hook.
+        """
+        return module
+
+    def reset_state(self, module: torch.nn.Module):
+        if self._is_stateful:
+            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
+        return module
+
+
+class HookRegistry:
+    def __init__(self, module_ref: torch.nn.Module) -> None:
+        super().__init__()
+
+        self.hooks: Dict[str, ModelHook] = {}
+
+        self._module_ref = module_ref
+        self._hook_order = []
+
+    def register_hook(self, hook: ModelHook, name: str) -> None:
+        if name in self.hooks.keys():
+            logger.warning(f"Hook with name {name} already exists, replacing it.")
+
+        if hasattr(self._module_ref, "_old_forward"):
+            old_forward = self._module_ref._old_forward
+        else:
+            old_forward = self._module_ref.forward
+            self._module_ref._old_forward = self._module_ref.forward
+
+        self._module_ref = hook.initialize_hook(self._module_ref)
+
+        if hasattr(hook, "new_forward"):
+            new_forward = hook.new_forward
+        else:
+
+            def new_forward(module, *args, **kwargs):
+                args, kwargs = hook.pre_forward(module, *args, **kwargs)
+                output = old_forward(*args, **kwargs)
+                return hook.post_forward(module, output)
+
+        new_forward = functools.update_wrapper(new_forward, old_forward)
+        self._module_ref.forward = new_forward.__get__(self._module_ref)
+
+        self.hooks[name] = hook
+        self._hook_order.append(name)
+
+    def get_hook(self, name: str) -> ModelHook:
+        if name not in self.hooks.keys():
+            raise ValueError(f"Hook with name {name} not found.")
+        return self.hooks[name]
+
+    def remove_hook(self, name: str) -> None:
+        if name not in self.hooks.keys():
+            raise ValueError(f"Hook with name {name} not found.")
+        self.hooks[name].deinitalize_hook(self._module_ref)
+        del self.hooks[name]
+        self._hook_order.remove(name)
+
+    @classmethod
+    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
+        if not hasattr(module, "_diffusers_hook"):
+            module._diffusers_hook = cls(module)
+        return module._diffusers_hook
+
+    def __repr__(self) -> str:
+        hook_repr = ""
+        for i, hook_name in enumerate(self._hook_order):
+            hook_repr += f"  ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
+            if i < len(self._hook_order) - 1:
+                hook_repr += "\n"
+        return f"HookRegistry(\n{hook_repr}\n)"
