@@ -31,6 +31,9 @@ class ModelHook:
3131
3232    _is_stateful  =  False 
3333
34+     def  __init__ (self ) ->  None :
35+         self .fn_ref  =  None 
36+ 
3437    def  initialize_hook (self , module : torch .nn .Module ) ->  torch .nn .Module :
3538        r""" 
3639        Hook that is executed when a model is initialized. 
@@ -103,6 +106,7 @@ def __init__(self) -> None:
103106        self .pre_forward  =  None 
104107        self .post_forward  =  None 
105108        self .old_forward  =  None 
109+         self .is_overwritten_forward  =  False 
106110
107111
108112class  HookRegistry :
@@ -119,40 +123,36 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
119123        if  name  in  self .hooks .keys ():
120124            logger .warning (f"Hook with name { name }  )
121125
122-         forward  =  self ._module_ref .forward 
123- 
124-         fn_ref  =  FunctionReference ()
125-         fn_ref .pre_forward  =  hook .pre_forward 
126-         fn_ref .post_forward  =  hook .post_forward 
127-         fn_ref .old_forward  =  forward 
128- 
129126        self ._module_ref  =  hook .initialize_hook (self ._module_ref )
130127
131-         def  create_new_forward (function_reference : FunctionReference ):
128+         def  create_new_forward (function_reference : FunctionReference ,  forward ):
132129            def  new_forward (module , * args , ** kwargs ):
133130                args , kwargs  =  function_reference .pre_forward (module , * args , ** kwargs )
134-                 output  =  function_reference . old_forward (* args , ** kwargs )
131+                 output  =  forward (* args , ** kwargs )
135132                return  function_reference .post_forward (module , output )
136133
137134            return  new_forward 
138135
139-         # if hasattr(hook, "new_forward"): 
140-         #     fn_ref.old_forward = hook.new_forward 
136+         forward  =  self ._module_ref .forward 
141137
142-         #     def new_forward(module, *args, **kwargs): 
143-         #         args, kwargs = hook.pre_forward(module, *args, **kwargs) 
144-         #         output = rewritten_forward(module, *args, **kwargs) 
145-         #         return hook.post_forward(module, output) 
146-         # else: 
138+         fn_ref  =  FunctionReference ()
139+         fn_ref .pre_forward  =  hook .pre_forward 
140+         fn_ref .post_forward  =  hook .post_forward 
141+         fn_ref .old_forward  =  forward 
147142
148-         #     def new_forward(module, *args, **kwargs): 
149-         #         args, kwargs = hook.pre_forward(module, *args, **kwargs) 
150-         #         output = forward(*args, **kwargs) 
151-         #         return hook.post_forward(module, output) 
143+         if  hasattr (hook , "new_forward" ):
144+             new_forward  =  hook .new_forward 
145+             fn_ref .is_overwritten_forward  =  True 
146+         else :
147+             new_forward  =  forward 
148+             fn_ref .is_overwritten_forward  =  False 
152149
153-         new_forward  =  create_new_forward (fn_ref )
154-         self ._module_ref .forward  =  functools .update_wrapper (functools .partial (new_forward , self ._module_ref ), forward )
150+         rewritten_forward  =  create_new_forward (fn_ref , new_forward )
151+         self ._module_ref .forward  =  functools .update_wrapper (
152+             functools .partial (rewritten_forward , self ._module_ref ), forward 
153+         )
155154
155+         hook .fn_ref  =  fn_ref 
156156        self .hooks [name ] =  hook 
157157        self ._hook_order .append (name )
158158        self ._fn_refs .append (fn_ref )
@@ -165,7 +165,6 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
165165        if  name  in  self .hooks .keys ():
166166            hook  =  self .hooks [name ]
167167            index  =  self ._hook_order .index (name )
168- 
169168            fn_ref  =  self ._fn_refs [index ]
170169
171170            if  index  ==  num_hooks  -  1 :
0 commit comments