1313# limitations under the License.
1414
1515import functools
16+ import gc
1617from typing import Any , Dict , Optional , Tuple
1718
1819import torch
@@ -30,6 +31,9 @@ class ModelHook:
3031
3132 _is_stateful = False
3233
34+ def __init__ (self ):
35+ self .fn_ref : "FunctionReference" = None
36+
3337 def initialize_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
3438 r"""
3539 Hook that is executed when a model is initialized.
@@ -48,8 +52,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
4852 module (`torch.nn.Module`):
4953 The module attached to this hook.
5054 """
51- module .forward = module ._old_forward
52- del module ._old_forward
5355 return module
5456
5557 def pre_forward (self , module : torch .nn .Module , * args , ** kwargs ) -> Tuple [Tuple [Any ], Dict [str , Any ]]:
@@ -99,6 +101,14 @@ def reset_state(self, module: torch.nn.Module):
99101 return module
100102
101103
104+ class FunctionReference :
105+ def __init__ (self ) -> None :
106+ self .pre_forward = None
107+ self .post_forward = None
108+ self .old_forward = None
109+ self .overwritten_forward = None
110+
111+
102112class HookRegistry :
103113 def __init__ (self , module_ref : torch .nn .Module ) -> None :
104114 super ().__init__ ()
@@ -107,51 +117,68 @@ def __init__(self, module_ref: torch.nn.Module) -> None:
107117
108118 self ._module_ref = module_ref
109119 self ._hook_order = []
120+ self ._fn_refs = []
110121
111122 def register_hook (self , hook : ModelHook , name : str ) -> None :
112123 if name in self .hooks .keys ():
113124 logger .warning (f"Hook with name { name } already exists, replacing it." )
114125
115- if hasattr (self ._module_ref , "_old_forward" ):
116- old_forward = self ._module_ref ._old_forward
117- else :
118- old_forward = self ._module_ref .forward
119- self ._module_ref ._old_forward = self ._module_ref .forward
120-
121126 self ._module_ref = hook .initialize_hook (self ._module_ref )
122127
123- if hasattr (hook , "new_forward" ):
124- rewritten_forward = hook .new_forward
125-
128+ def create_new_forward (function_reference : FunctionReference ):
126129 def new_forward (module , * args , ** kwargs ):
127- args , kwargs = hook .pre_forward (module , * args , ** kwargs )
128- output = rewritten_forward (module , * args , ** kwargs )
129- return hook .post_forward (module , output )
130- else :
130+ args , kwargs = function_reference .pre_forward (module , * args , ** kwargs )
131+ output = function_reference .old_forward (* args , ** kwargs )
132+ return function_reference .post_forward (module , output )
131133
132- def new_forward (module , * args , ** kwargs ):
133- args , kwargs = hook .pre_forward (module , * args , ** kwargs )
134- output = old_forward (* args , ** kwargs )
135- return hook .post_forward (module , output )
134+ return new_forward
135+
136+ forward = self ._module_ref .forward
136137
138+ fn_ref = FunctionReference ()
139+ fn_ref .pre_forward = hook .pre_forward
140+ fn_ref .post_forward = hook .post_forward
141+ fn_ref .old_forward = forward
142+
143+ if hasattr (hook , "new_forward" ):
144+ fn_ref .overwritten_forward = forward
145+ fn_ref .old_forward = functools .update_wrapper (
146+ functools .partial (hook .new_forward , self ._module_ref ), hook .new_forward
147+ )
148+
149+ rewritten_forward = create_new_forward (fn_ref )
137150 self ._module_ref .forward = functools .update_wrapper (
138- functools .partial (new_forward , self ._module_ref ), old_forward
151+ functools .partial (rewritten_forward , self ._module_ref ), rewritten_forward
139152 )
140153
154+ hook .fn_ref = fn_ref
141155 self .hooks [name ] = hook
142156 self ._hook_order .append (name )
157+ self ._fn_refs .append (fn_ref )
143158
144159 def get_hook (self , name : str ) -> Optional [ModelHook ]:
145- if name not in self .hooks .keys ():
146- return None
147- return self .hooks [name ]
160+ return self .hooks .get (name , None )
148161
149162 def remove_hook (self , name : str , recurse : bool = True ) -> None :
163+ num_hooks = len (self ._hook_order )
150164 if name in self .hooks .keys ():
151165 hook = self .hooks [name ]
166+ index = self ._hook_order .index (name )
167+ fn_ref = self ._fn_refs [index ]
168+
169+ old_forward = fn_ref .old_forward
170+ if fn_ref .overwritten_forward is not None :
171+ old_forward = fn_ref .overwritten_forward
172+
173+ if index == num_hooks - 1 :
174+ self ._module_ref .forward = old_forward
175+ else :
176+ self ._fn_refs [index + 1 ].old_forward = old_forward
177+
152178 self ._module_ref = hook .deinitalize_hook (self ._module_ref )
153179 del self .hooks [name ]
154- self ._hook_order .remove (name )
180+ self ._hook_order .pop (index )
181+ self ._fn_refs .pop (index )
155182
156183 if recurse :
157184 for module_name , module in self ._module_ref .named_modules ():
@@ -160,8 +187,10 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
160187 if hasattr (module , "_diffusers_hook" ):
161188 module ._diffusers_hook .remove_hook (name , recurse = False )
162189
190+ gc .collect ()
191+
163192 def reset_stateful_hooks (self , recurse : bool = True ) -> None :
164- for hook_name in self ._hook_order :
193+ for hook_name in reversed ( self ._hook_order ) :
165194 hook = self .hooks [hook_name ]
166195 if hook ._is_stateful :
167196 hook .reset_state (self ._module_ref )
@@ -180,9 +209,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry
180209 return module ._diffusers_hook
181210
182211 def __repr__ (self ) -> str :
183- hook_repr = ""
212+ registry_repr = ""
184213 for i , hook_name in enumerate (self ._hook_order ):
185- hook_repr += f" ({ i } ) { hook_name } - ({ self .hooks [hook_name ].__class__ .__name__ } )"
214+ if self .hooks [hook_name ].__class__ .__repr__ is not object .__repr__ :
215+ hook_repr = self .hooks [hook_name ].__repr__ ()
216+ else :
217+ hook_repr = self .hooks [hook_name ].__class__ .__name__
218+ registry_repr += f" ({ i } ) { hook_name } - { hook_repr } "
186219 if i < len (self ._hook_order ) - 1 :
187- hook_repr += "\n "
188- return f"HookRegistry(\n { hook_repr } \n )"
220+ registry_repr += "\n "
221+ return f"HookRegistry(\n { registry_repr } \n )"
0 commit comments