1313# limitations under the License.
1414
1515import functools
16- from typing import Any , Dict , Tuple
16+ from typing import Any , Dict , List , Tuple
1717
1818import torch
1919
@@ -25,6 +25,8 @@ class ModelHook:
2525 with PyTorch existing hooks is that they get passed along the kwargs.
2626 """
2727
28+ _is_stateful = False
29+
2830 def init_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
2931 r"""
3032 Hook that is executed when a model is initialized.
@@ -75,13 +77,17 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
7577 The module detached from this hook.
7678 """
7779 return module
80+
81+ def reset_state (self ):
82+ if self ._is_stateful :
83+ raise NotImplementedError ("This hook is stateful and needs to implement the `reset_state` method." )
7884
7985
8086class SequentialHook (ModelHook ):
8187 r"""A hook that can contain several hooks and iterates through them at each event."""
8288
8389 def __init__ (self , * hooks ):
84- self .hooks = hooks
90+ self .hooks : List [ ModelHook ] = hooks
8591
8692 def init_hook (self , module ):
8793 for hook in self .hooks :
@@ -102,6 +108,11 @@ def detach_hook(self, module):
102108 for hook in self .hooks :
103109 module = hook .detach_hook (module )
104110 return module
111+
112+ def reset_state (self ):
113+ for hook in self .hooks :
114+ if hook ._is_stateful :
115+ hook .reset_state ()
105116
106117
107118def add_hook_to_module (module : torch .nn .Module , hook : ModelHook , append : bool = False ):
@@ -195,3 +206,19 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t
195206 remove_hook_from_module (child , recurse )
196207
197208 return module
209+
210+
211+ def reset_stateful_hooks (module : torch .nn .Module , recurse : bool = False ):
212+ """
213+ Resets the state of all stateful hooks attached to a module.
214+
215+ Args:
216+ module (`torch.nn.Module`):
217+ The module to reset the stateful hooks from.
218+ """
219+ if hasattr (module , "_diffusers_hook" ) and (module ._diffusers_hook ._is_stateful or isinstance (module ._diffusers_hook , SequentialHook )):
220+ module ._diffusers_hook .reset_state (module )
221+
222+ if recurse :
223+ for child in module .children ():
224+ reset_stateful_hooks (child , recurse )
0 commit comments