@@ -1846,8 +1846,10 @@ def run_with_hooks(
18461846 # Store hooks that we add so we can remove them later
18471847 added_hooks : List [Tuple [HookPoint , str ]] = []
18481848
1849- def add_hook_to_point (hook_point : HookPoint , hook_fn : Callable , name : str ):
1850- hook_point .add_hook (hook_fn )
1849+ def add_hook_to_point (
1850+ hook_point : HookPoint , hook_fn : Callable , name : str , dir : Literal ["fwd" , "bwd" ] = "fwd"
1851+ ):
1852+ hook_point .add_hook (hook_fn , dir = dir )
18511853 added_hooks .append ((hook_point , name ))
18521854
18531855 # Add stop_at_layer hook if specified
@@ -1868,10 +1870,11 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
18681870 block_hook_name = f"blocks.{ last_layer_to_process } .hook_out"
18691871 hook_dict = self .hook_dict
18701872 if block_hook_name in hook_dict :
1871- add_hook_to_point (hook_dict [block_hook_name ], stop_hook , block_hook_name )
1873+ add_hook_to_point (hook_dict [block_hook_name ], stop_hook , block_hook_name , "fwd" )
18721874
18731875 # Helper function to apply hooks based on name or filter function
18741876 def apply_hooks (hooks : List [Tuple [Union [str , Callable ], Callable ]], is_fwd : bool ):
1877+ direction : Literal ["fwd" , "bwd" ] = "fwd" if is_fwd else "bwd"
18751878 # Collect aliases for resolving legacy hook names
18761879 aliases = collect_aliases_recursive (self )
18771880
@@ -1904,13 +1907,15 @@ def wrapped_hook_fn(tensor, hook):
19041907 actual_hook_name = aliases [hook_name_or_filter ]
19051908
19061909 if actual_hook_name in hook_dict :
1907- add_hook_to_point (hook_dict [actual_hook_name ], hook_fn , actual_hook_name )
1910+ add_hook_to_point (
1911+ hook_dict [actual_hook_name ], hook_fn , actual_hook_name , direction
1912+ )
19081913 else :
19091914 # Filter function
19101915 hook_dict = self .hook_dict
19111916 for name , hook_point in hook_dict .items ():
19121917 if hook_name_or_filter (name ):
1913- add_hook_to_point (hook_point , hook_fn , name )
1918+ add_hook_to_point (hook_point , hook_fn , name , direction )
19141919
19151920 try :
19161921 # Apply forward hooks
@@ -2330,10 +2335,18 @@ def _hooks_context():
23302335 # Add forward hooks
23312336 for hook_name , hook_fn in fwd_hooks :
23322337 try :
2333- self .add_hook (hook_name , hook_fn )
2338+ self .add_hook (hook_name , hook_fn , dir = "fwd" )
2339+ added_hooks .append ((hook_name , hook_fn ))
2340+ except Exception as e :
2341+ print (f"Warning: Failed to add forward hook { hook_name } : { e } " )
2342+
2343+ # Add backward hooks
2344+ for hook_name , hook_fn in bwd_hooks :
2345+ try :
2346+ self .add_hook (hook_name , hook_fn , dir = "bwd" )
23342347 added_hooks .append ((hook_name , hook_fn ))
23352348 except Exception as e :
2336- print (f"Warning: Failed to add hook { hook_name } : { e } " )
2349+ print (f"Warning: Failed to add backward hook { hook_name } : { e } " )
23372350
23382351 yield
23392352
0 commit comments