Skip to content

Commit 6047114

Browse files
committed
update
1 parent 80c5acd commit 6047114

File tree

3 files changed

+298
-67
lines changed

3 files changed

+298
-67
lines changed

src/diffusers/models/hooks.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16-
from typing import Any, Dict, Tuple
16+
from typing import Any, Dict, List, Tuple
1717

1818
import torch
1919

@@ -25,6 +25,8 @@ class ModelHook:
2525
with PyTorch existing hooks is that they get passed along the kwargs.
2626
"""
2727

28+
_is_stateful = False
29+
2830
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
2931
r"""
3032
Hook that is executed when a model is initialized.
@@ -75,13 +77,17 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
7577
The module detached from this hook.
7678
"""
7779
return module
80+
81+
def reset_state(self):
82+
if self._is_stateful:
83+
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
7884

7985

8086
class SequentialHook(ModelHook):
8187
r"""A hook that can contain several hooks and iterates through them at each event."""
8288

8389
def __init__(self, *hooks):
84-
self.hooks = hooks
90+
self.hooks: List[ModelHook] = hooks
8591

8692
def init_hook(self, module):
8793
for hook in self.hooks:
@@ -102,6 +108,11 @@ def detach_hook(self, module):
102108
for hook in self.hooks:
103109
module = hook.detach_hook(module)
104110
return module
111+
112+
def reset_state(self):
113+
for hook in self.hooks:
114+
if hook._is_stateful:
115+
hook.reset_state()
105116

106117

107118
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
@@ -195,3 +206,19 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
195206
remove_hook_from_module(child, recurse)
196207

197208
return module
209+
210+
211+
def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False):
212+
"""
213+
Resets the state of all stateful hooks attached to a module.
214+
215+
Args:
216+
module (`torch.nn.Module`):
217+
The module to reset the stateful hooks from.
218+
"""
219+
if hasattr(module, "_diffusers_hook") and (module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)):
220+
module._diffusers_hook.reset_state(module)
221+
222+
if recurse:
223+
for child in module.children():
224+
reset_stateful_hooks(child, recurse)

0 commit comments

Comments
 (0)