1313# limitations under the License.
1414
1515import functools
16- from typing import Any , Dict , Tuple
16+ from typing import Any , Dict , Optional , Tuple
1717
1818import torch
1919
@@ -33,7 +33,6 @@ class ModelHook:
3333 def initialize_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
3434 r"""
3535 Hook that is executed when a model is initialized.
36-
3736 Args:
3837 module (`torch.nn.Module`):
3938 The module attached to this hook.
@@ -43,7 +42,6 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
4342 def deinitalize_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
4443 r"""
4544 Hook that is executed when a model is deinitalized.
46-
4745 Args:
4846 module (`torch.nn.Module`):
4947 The module attached to this hook.
@@ -55,15 +53,13 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
5553 def pre_forward (self , module : torch .nn .Module , * args , ** kwargs ) -> Tuple [Tuple [Any ], Dict [str , Any ]]:
5654 r"""
5755 Hook that is executed just before the forward method of the model.
58-
5956 Args:
6057 module (`torch.nn.Module`):
6158 The module whose forward pass will be executed just after this event.
6259 args (`Tuple[Any]`):
6360 The positional arguments passed to the module.
6461 kwargs (`Dict[Str, Any]`):
6562 The keyword arguments passed to the module.
66-
6763 Returns:
6864 `Tuple[Tuple[Any], Dict[Str, Any]]`:
6965 A tuple with the treated `args` and `kwargs`.
@@ -73,13 +69,11 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
7369 def post_forward (self , module : torch .nn .Module , output : Any ) -> Any :
7470 r"""
7571 Hook that is executed just after the forward method of the model.
76-
7772 Args:
7873 module (`torch.nn.Module`):
7974 The module whose forward pass been executed just before this event.
8075 output (`Any`):
8176 The output of the module.
82-
8377 Returns:
8478 `Any`: The processed `output`.
8579 """
@@ -88,7 +82,6 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
8882 def detach_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
8983 r"""
9084 Hook that is executed when the hook is detached from a module.
91-
9285 Args:
9386 module (`torch.nn.Module`):
9487 The module detached from this hook.
@@ -123,31 +116,57 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
123116 self ._module_ref = hook .initialize_hook (self ._module_ref )
124117
125118 if hasattr (hook , "new_forward" ):
126- new_forward = hook .new_forward
119+ rewritten_forward = hook .new_forward
120+
121+ def new_forward (module , * args , ** kwargs ):
122+ args , kwargs = hook .pre_forward (module , * args , ** kwargs )
123+ output = rewritten_forward (module , * args , ** kwargs )
124+ return hook .post_forward (module , output )
127125 else :
128126
129127 def new_forward (module , * args , ** kwargs ):
130128 args , kwargs = hook .pre_forward (module , * args , ** kwargs )
131129 output = old_forward (* args , ** kwargs )
132130 return hook .post_forward (module , output )
133131
134- new_forward = functools .update_wrapper (new_forward , old_forward )
135- self ._module_ref .forward = new_forward .__get__ (self ._module_ref )
132+ self ._module_ref .forward = functools .update_wrapper (
133+ functools .partial (new_forward , self ._module_ref ), old_forward
134+ )
136135
137136 self .hooks [name ] = hook
138137 self ._hook_order .append (name )
139138
140- def get_hook (self , name : str ) -> ModelHook :
139+ def get_hook (self , name : str ) -> Optional [ ModelHook ] :
141140 if name not in self .hooks .keys ():
142- raise ValueError ( f"Hook with name { name } not found." )
141+ return None
143142 return self .hooks [name ]
144143
145- def remove_hook (self , name : str ) -> None :
146- if name not in self .hooks .keys ():
147- raise ValueError (f"Hook with name { name } not found." )
148- self .hooks [name ].deinitalize_hook (self ._module_ref )
149- del self .hooks [name ]
150- self ._hook_order .remove (name )
144+ def remove_hook (self , name : str , recurse : bool = True ) -> None :
145+ if name in self .hooks .keys ():
146+ hook = self .hooks [name ]
147+ self ._module_ref = hook .deinitalize_hook (self ._module_ref )
148+ del self .hooks [name ]
149+ self ._hook_order .remove (name )
150+
151+ if recurse :
152+ for module_name , module in self ._module_ref .named_modules ():
153+ if module_name == "" :
154+ continue
155+ if hasattr (module , "_diffusers_hook" ):
156+ module ._diffusers_hook .remove_hook (name , recurse = False )
157+
158+ def reset_stateful_hooks (self , recurse : bool = True ) -> None :
159+ for hook_name in self ._hook_order :
160+ hook = self .hooks [hook_name ]
161+ if hook ._is_stateful :
162+ hook .reset_state (self ._module_ref )
163+
164+ if recurse :
165+ for module_name , module in self ._module_ref .named_modules ():
166+ if module_name == "" :
167+ continue
168+ if hasattr (module , "_diffusers_hook" ):
169+ module ._diffusers_hook .reset_stateful_hooks (recurse = False )
151170
152171 @classmethod
153172 def check_if_exists_or_initialize (cls , module : torch .nn .Module ) -> "HookRegistry" :
0 commit comments