@@ -131,28 +131,6 @@ def _initialize_hook_registry(self) -> None:
131131 # Scan existing components for hooks
132132 self ._scan_existing_hooks (self , "" )
133133
134- # Add bridge aliases if compatibility mode is enabled
135- if self .compatibility_mode and self .hook_aliases :
136- for alias_name , target in self .hook_aliases .items ():
137- # Use the existing alias system to resolve the target hook
138- # Convert to Dict[str, str] for resolve_alias if target_name is a list
139- if isinstance (target , list ):
140- # For list targets, try each one until one works
141- for single_target in target :
142- try :
143- target_hook = resolve_alias (
144- self , alias_name , {alias_name : single_target }
145- )
146- if target_hook is not None :
147- self ._hook_registry [alias_name ] = target_hook
148- break
149- except AttributeError :
150- continue
151- else :
152- target_hook = resolve_alias (self , alias_name , {alias_name : target })
153- if target_hook is not None :
154- self ._hook_registry [alias_name ] = target_hook
155-
156134 self ._hook_registry_initialized = True
157135
158136 def _scan_existing_hooks (self , module : nn .Module , prefix : str = "" ) -> None :
@@ -180,41 +158,39 @@ def scan_module(mod: nn.Module, path: str = "") -> None:
180158 except Exception :
181159 # If get_hooks() fails, fall through to the else block
182160 pass
183- else :
184- # Fall back to scanning attributes for non-GeneralizedComponent modules
185- for attr_name in dir (mod ):
186- if attr_name .startswith ("_" ):
187- continue
188- if attr_name == "original_component" :
189- continue
190161
191- try :
192- attr = getattr (mod , attr_name )
193- except Exception :
194- continue
195-
196- name = f"{ path } .{ attr_name } " if path else attr_name
197-
198- if isinstance (attr , HookPoint ):
199- attr .name = name
200- self ._hook_registry [name ] = attr
201- elif isinstance (attr , HookPointWrapper ):
202- hook_in_name = f"{ name } .hook_in"
203- hook_out_name = f"{ name } .hook_out"
204- attr .hook_in .name = hook_in_name
205- attr .hook_out .name = hook_out_name
206- self ._hook_registry [hook_in_name ] = attr .hook_in
207- self ._hook_registry [hook_out_name ] = attr .hook_out
208- elif isinstance (attr , nn .Module ) and attr is not mod :
209- scan_module (attr , name )
210- elif isinstance (attr , (list , tuple )):
211- for i , item in enumerate (attr ):
212- if isinstance (item , nn .Module ):
213- scan_module (item , f"{ name } [{ i } ]" )
162+ # Always scan attributes for additional hooks and submodules
163+ for attr_name in dir (mod ):
164+ if attr_name .startswith ("_" ):
165+ continue
166+ if attr_name == "original_component" or "original_model" :
167+ continue
168+
169+ try :
170+ attr = getattr (mod , attr_name )
171+ except Exception :
172+ continue
173+
174+ name = f"{ path } .{ attr_name } " if path else attr_name
175+
176+ if isinstance (attr , HookPoint ):
177+ attr .name = name
178+ self ._hook_registry [name ] = attr
179+ elif isinstance (attr , HookPointWrapper ):
180+ hook_in_name = f"{ name } .hook_in"
181+ hook_out_name = f"{ name } .hook_out"
182+ attr .hook_in .name = hook_in_name
183+ attr .hook_out .name = hook_out_name
184+ self ._hook_registry [hook_in_name ] = attr .hook_in
185+ self ._hook_registry [hook_out_name ] = attr .hook_out
214186
215187 # Check named children
216188 for child_name , child_module in mod .named_children ():
217- if child_name == "original_component" or child_name == "_original_component" :
189+ if (
190+ child_name == "original_component"
191+ or child_name == "_original_component"
192+ or child_name == "original_model"
193+ ):
218194 continue
219195 child_path = f"{ path } .{ child_name } " if path else child_name
220196 scan_module (child_module , child_path )
@@ -225,23 +201,7 @@ def scan_module(mod: nn.Module, path: str = "") -> None:
225201 def hook_dict (self ) -> dict [str , HookPoint ]:
226202 """Get all HookPoint objects in the model for compatibility with HookedTransformer."""
227203 # Start with the current registry
228- hooks = self ._hook_registry .copy ()
229-
230- # Add aliases if compatibility mode is enabled
231- if self .compatibility_mode :
232- for alias_name , target in self .hook_aliases .items ():
233- # Handle both string and list target names
234- if isinstance (target , list ):
235- # For list targets, find the first one that exists in hooks
236- for single_target in target :
237- if single_target in hooks :
238- hooks [alias_name ] = hooks [single_target ]
239- break
240- else :
241- if target in hooks :
242- hooks [alias_name ] = hooks [target ]
243-
244- return hooks
204+ return self ._hook_registry .copy ()
245205
246206 def _discover_hooks (self ) -> dict [str , HookPoint ]:
247207 """Get all HookPoint objects from the registry (deprecated, use hook_dict)."""
@@ -262,17 +222,10 @@ def __getattr__(self, name: str) -> Any:
262222 return self .__dict__ [name ]
263223
264224 # Check if this is a hook alias when compatibility mode is enabled
265- if self .compatibility_mode and name in self .hook_aliases :
266- target = self .hook_aliases [name ]
267- # Handle both string and list target names
268- if isinstance (target , list ):
269- # For list targets, find the first one that exists in the registry
270- for single_target in target :
271- if single_target in self ._hook_registry :
272- return self ._hook_registry [single_target ]
273- else :
274- if target in self ._hook_registry :
275- return self ._hook_registry [target ]
225+ if self .compatibility_mode :
226+ resolved_hook = resolve_alias (self , name , self .hook_aliases )
227+ if resolved_hook is not None :
228+ return resolved_hook
276229
277230 return super ().__getattr__ (name )
278231
@@ -982,7 +935,7 @@ def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
982935 return cache_hook
983936
984937 # Use cached hooks instead of re-discovering them
985- hook_dict = self .hook_dict
938+ hook_dict = self ._hook_registry
986939
987940 # Filter hooks based on names_filter
988941 for hook_name , hook in hook_dict .items ():
@@ -1026,7 +979,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
1026979
1027980 # Add hook to the output of the last layer to be processed
1028981 block_hook_name = f"blocks.{ last_layer_to_process } .hook_out"
1029- hook_dict = self .hook_dict
982+ hook_dict = self ._hook_registry
1030983 if block_hook_name in hook_dict :
1031984 hook_dict [block_hook_name ].add_hook (stop_hook )
1032985 hooks .append ((hook_dict [block_hook_name ], block_hook_name ))
@@ -1167,7 +1120,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
11671120
11681121 # Add hook to the output of the last layer to be processed
11691122 block_hook_name = f"blocks.{ last_layer_to_process } .hook_out"
1170- hook_dict = self .hook_dict
1123+ hook_dict = self ._hook_registry
11711124 if block_hook_name in hook_dict :
11721125 add_hook_to_point (hook_dict [block_hook_name ], stop_hook , block_hook_name )
11731126
@@ -1197,7 +1150,7 @@ def wrapped_hook_fn(tensor, hook):
11971150
11981151 if isinstance (hook_name_or_filter , str ):
11991152 # Direct hook name - check for aliases first
1200- hook_dict = self .hook_dict
1153+ hook_dict = self ._hook_registry
12011154 actual_hook_name = hook_name_or_filter
12021155
12031156 # If this is an alias, resolve it to the actual hook name
@@ -1208,7 +1161,7 @@ def wrapped_hook_fn(tensor, hook):
12081161 add_hook_to_point (hook_dict [actual_hook_name ], hook_fn , actual_hook_name )
12091162 else :
12101163 # Filter function
1211- hook_dict = self .hook_dict
1164+ hook_dict = self ._hook_registry
12121165 for name , hook_point in hook_dict .items ():
12131166 if hook_name_or_filter (name ):
12141167 add_hook_to_point (hook_point , hook_fn , name )
0 commit comments