@@ -32,7 +32,7 @@ class ModelHook:
3232 _is_stateful = False
3333
3434 def __init__ (self ):
35- self .fn_ref : "FunctionReference " = None
35+ self .fn_ref : "HookFunctionReference " = None
3636
3737 def initialize_hook (self , module : torch .nn .Module ) -> torch .nn .Module :
3838 r"""
@@ -101,12 +101,27 @@ def reset_state(self, module: torch.nn.Module):
101101 return module
102102
103103
104- class FunctionReference :
104+ class HookFunctionReference :
105105 def __init__ (self ) -> None :
106+ """A container class that maintains mutable references to forward pass functions in a hook chain.
107+
108+ Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
109+ entire forward pass structure.
110+
111+ Attributes:
112+ pre_forward: A callable that processes inputs before the main forward pass.
113+ post_forward: A callable that processes outputs after the main forward pass.
114+ forward: The current forward function in the hook chain.
115+ original_forward: The original forward function, stored when a hook provides a custom new_forward.
116+
117+ The class enables hook removal by allowing updates to the forward chain through reference modification rather
118+ than requiring reconstruction of the entire chain. When a hook is removed, only the relevant references need to
119+ be updated, preserving the execution order of the remaining hooks.
120+ """
106121 self .pre_forward = None
107122 self .post_forward = None
108- self .old_forward = None
109- self .overwritten_forward = None
123+ self .forward = None
124+ self .original_forward = None
110125
111126
112127class HookRegistry :
@@ -125,24 +140,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
125140
126141 self ._module_ref = hook .initialize_hook (self ._module_ref )
127142
128- def create_new_forward (function_reference : FunctionReference ):
143+ def create_new_forward (function_reference : HookFunctionReference ):
129144 def new_forward (module , * args , ** kwargs ):
130145 args , kwargs = function_reference .pre_forward (module , * args , ** kwargs )
131- output = function_reference .old_forward (* args , ** kwargs )
146+ output = function_reference .forward (* args , ** kwargs )
132147 return function_reference .post_forward (module , output )
133148
134149 return new_forward
135150
136151 forward = self ._module_ref .forward
137152
138- fn_ref = FunctionReference ()
153+ fn_ref = HookFunctionReference ()
139154 fn_ref .pre_forward = hook .pre_forward
140155 fn_ref .post_forward = hook .post_forward
141- fn_ref .old_forward = forward
156+ fn_ref .forward = forward
142157
143158 if hasattr (hook , "new_forward" ):
144- fn_ref .overwritten_forward = forward
145- fn_ref .old_forward = functools .update_wrapper (
159+ fn_ref .original_forward = forward
160+ fn_ref .forward = functools .update_wrapper (
146161 functools .partial (hook .new_forward , self ._module_ref ), hook .new_forward
147162 )
148163
@@ -160,25 +175,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
160175 return self .hooks .get (name , None )
161176
162177 def remove_hook (self , name : str , recurse : bool = True ) -> None :
163- num_hooks = len (self ._hook_order )
164- if name in self .hooks .keys ():
165- hook = self .hooks [name ]
166- index = self ._hook_order .index (name )
167- fn_ref = self ._fn_refs [index ]
168-
169- old_forward = fn_ref .old_forward
170- if fn_ref .overwritten_forward is not None :
171- old_forward = fn_ref .overwritten_forward
178+ if name not in self .hooks .keys ():
179+ logger .warning (f"hook: { name } was not found in HookRegistry" )
180+ return
172181
173- if index == num_hooks - 1 :
174- self ._module_ref .forward = old_forward
175- else :
176- self ._fn_refs [index + 1 ].old_forward = old_forward
177-
178- self ._module_ref = hook .deinitalize_hook (self ._module_ref )
179- del self .hooks [name ]
180- self ._hook_order .pop (index )
181- self ._fn_refs .pop (index )
182+ num_hooks = len (self ._hook_order )
183+ hook = self .hooks [name ]
184+ index = self ._hook_order .index (name )
185+ fn_ref = self ._fn_refs [index ]
186+
187+ old_forward = fn_ref .forward
188+ if fn_ref .original_forward is not None :
189+ old_forward = fn_ref .original_forward
190+
191+ if index == num_hooks - 1 :
192+ self ._module_ref .forward = old_forward
193+ else :
194+ self ._fn_refs [index + 1 ].forward = old_forward
195+
196+ self ._module_ref = hook .deinitalize_hook (self ._module_ref )
197+ del self .hooks [name ]
198+ self ._hook_order .pop (index )
199+ self ._fn_refs .pop (index )
182200
183201 if recurse :
184202 for module_name , module in self ._module_ref .named_modules ():
0 commit comments