@@ -32,7 +32,7 @@ class ModelHook:
3232    _is_stateful  =  False 
3333
3434    def  __init__ (self ):
35-         self .fn_ref : "FunctionReference "  =  None 
35+         self .fn_ref : "HookFunctionReference "  =  None 
3636
3737    def  initialize_hook (self , module : torch .nn .Module ) ->  torch .nn .Module :
3838        r""" 
@@ -101,12 +101,17 @@ def reset_state(self, module: torch.nn.Module):
101101        return  module 
102102
103103
104- class  FunctionReference :
104+ class  HookFunctionReference :
105105    def  __init__ (self ) ->  None :
106+         """ 
107+         Holding class for forward functions references used in Diffusers hooks. This struct allows you to easily swap 
108+         out the forward function in the when a hook is removed from a modules hook registry. 
109+ 
110+         """ 
106111        self .pre_forward  =  None 
107112        self .post_forward  =  None 
108-         self .old_forward  =  None 
109-         self .overwritten_forward  =  None 
113+         self .forward  =  None 
114+         self .original_forward  =  None 
110115
111116
112117class  HookRegistry :
@@ -125,24 +130,24 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
125130
126131        self ._module_ref  =  hook .initialize_hook (self ._module_ref )
127132
128-         def  create_new_forward (function_reference : FunctionReference ):
133+         def  create_new_forward (function_reference : HookFunctionReference ):
129134            def  new_forward (module , * args , ** kwargs ):
130135                args , kwargs  =  function_reference .pre_forward (module , * args , ** kwargs )
131-                 output  =  function_reference .old_forward (* args , ** kwargs )
136+                 output  =  function_reference .forward (* args , ** kwargs )
132137                return  function_reference .post_forward (module , output )
133138
134139            return  new_forward 
135140
136141        forward  =  self ._module_ref .forward 
137142
138-         fn_ref  =  FunctionReference ()
143+         fn_ref  =  HookFunctionReference ()
139144        fn_ref .pre_forward  =  hook .pre_forward 
140145        fn_ref .post_forward  =  hook .post_forward 
141-         fn_ref .old_forward  =  forward 
146+         fn_ref .forward  =  forward 
142147
143148        if  hasattr (hook , "new_forward" ):
144-             fn_ref .overwritten_forward  =  forward 
145-             fn_ref .old_forward  =  functools .update_wrapper (
149+             fn_ref .original_forward  =  forward 
150+             fn_ref .forward  =  functools .update_wrapper (
146151                functools .partial (hook .new_forward , self ._module_ref ), hook .new_forward 
147152            )
148153
@@ -160,25 +165,28 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
160165        return  self .hooks .get (name , None )
161166
162167    def  remove_hook (self , name : str , recurse : bool  =  True ) ->  None :
163-         num_hooks  =  len (self ._hook_order )
164-         if  name  in  self .hooks .keys ():
165-             hook  =  self .hooks [name ]
166-             index  =  self ._hook_order .index (name )
167-             fn_ref  =  self ._fn_refs [index ]
168+         if  name  not  in self .hooks .keys ():
169+             logger .warning (f"hook: { name }  )
170+             return 
168171
169-             old_forward  =  fn_ref .old_forward 
170-             if  fn_ref .overwritten_forward  is  not None :
171-                 old_forward  =  fn_ref .overwritten_forward 
172- 
173-             if  index  ==  num_hooks  -  1 :
174-                 self ._module_ref .forward  =  old_forward 
175-             else :
176-                 self ._fn_refs [index  +  1 ].old_forward  =  old_forward 
177- 
178-             self ._module_ref  =  hook .deinitalize_hook (self ._module_ref )
179-             del  self .hooks [name ]
180-             self ._hook_order .pop (index )
181-             self ._fn_refs .pop (index )
172+         num_hooks  =  len (self ._hook_order )
173+         hook  =  self .hooks [name ]
174+         index  =  self ._hook_order .index (name )
175+         fn_ref  =  self ._fn_refs [index ]
176+ 
177+         old_forward  =  fn_ref .forward 
178+         if  fn_ref .original_forward  is  not None :
179+             old_forward  =  fn_ref .original_forward 
180+ 
181+         if  index  ==  num_hooks  -  1 :
182+             self ._module_ref .forward  =  old_forward 
183+         else :
184+             self ._fn_refs [index  +  1 ].forward  =  old_forward 
185+ 
186+         self ._module_ref  =  hook .deinitalize_hook (self ._module_ref )
187+         del  self .hooks [name ]
188+         self ._hook_order .pop (index )
189+         self ._fn_refs .pop (index )
182190
183191        if  recurse :
184192            for  module_name , module  in  self ._module_ref .named_modules ():
0 commit comments