3535)
3636from transformer_lens .model_bridge .hook_point_wrapper import HookPointWrapper
3737from transformer_lens .model_bridge .types import ComponentMapping
38- from transformer_lens .utilities .aliases import collect_aliases_recursive
38+ from transformer_lens .utilities .aliases import collect_aliases_recursive , resolve_alias
3939
4040if TYPE_CHECKING :
4141 from transformer_lens .ActivationCache import ActivationCache
@@ -54,6 +54,7 @@ class TransformerBridge(nn.Module):
5454 hook_aliases = {
5555 "hook_embed" : "embed.hook_out" ,
5656 "hook_pos_embed" : "pos_embed.hook_out" ,
57+ "hook_unembed" : "unembed.hook_out" ,
5758 }
5859
5960 def __init__ (self , model : nn .Module , adapter : ArchitectureAdapter , tokenizer : Any ):
@@ -70,6 +71,11 @@ def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: An
7071 self .cfg = adapter .cfg
7172 self .tokenizer = tokenizer
7273 self .compatibility_mode = False
74+ self ._hook_cache = None # Cache for hook discovery results
75+ self ._hook_registry : Dict [
76+ str , HookPoint
77+ ] = {} # Dynamic registry of hook names to HookPoints
78+ self ._hook_registry_initialized = False # Track if registry has been initialized
7379
7480 # Add device information to config from the loaded model
7581 if not hasattr (self .cfg , "device" ):
@@ -84,68 +90,159 @@ def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: An
8490 # Set original components on the pre-created bridge components
8591 set_original_components (self , self .adapter , self .original_model )
8692
87- @property
88- def hook_dict (self ) -> dict [str , HookPoint ]:
89- """Get all HookPoint objects in the model for compatibility with HookedTransformer."""
90- hooks = {}
91- visited = set () # Move visited set outside the recursive function
92-
93- def collect_hookpoints (module : nn .Module , prefix : str = "" ) -> None :
94- """Recursively collect all HookPoint objects."""
95- obj_id = id (module )
93+ # Initialize hook registry after components are set up
94+ self ._initialize_hook_registry ()
95+
96+ def __setattr__ (self , name : str , value : Any ) -> None :
97+ """Override setattr to track HookPoint objects dynamically."""
98+ # Call parent setattr first
99+ super ().__setattr__ (name , value )
100+
101+ # Check if this is a HookPoint being set
102+ if isinstance (value , HookPoint ):
103+ # Set the name on the HookPoint
104+ value .name = name
105+ # Add to registry
106+ self ._hook_registry [name ] = value
107+ elif isinstance (value , HookPointWrapper ):
108+ # Handle HookPointWrapper objects
109+ hook_in_name = f"{ name } .hook_in"
110+ hook_out_name = f"{ name } .hook_out"
111+ value .hook_in .name = hook_in_name
112+ value .hook_out .name = hook_out_name
113+ self ._hook_registry [hook_in_name ] = value .hook_in
114+ self ._hook_registry [hook_out_name ] = value .hook_out
115+ elif hasattr (value , "get_hooks" ) and callable (getattr (value , "get_hooks" )):
116+ # This is a GeneralizedComponent being set
117+ # We need to register its hooks with the appropriate prefix
118+ component_hooks = value .get_hooks ()
119+ for hook_name , hook in component_hooks .items ():
120+ full_name = f"{ name } .{ hook_name } "
121+ hook .name = full_name
122+ self ._hook_registry [full_name ] = hook
123+
124+ def _initialize_hook_registry (self ) -> None :
125+ """Initialize the hook registry by scanning existing components."""
126+ if self ._hook_registry_initialized :
127+ return
128+
129+ # Scan existing components for hooks
130+ self ._scan_existing_hooks (self , "" )
131+
132+ # Add bridge aliases if compatibility mode is enabled
133+ if self .compatibility_mode and self .hook_aliases :
134+ for alias_name , target_name in self .hook_aliases .items ():
135+ # Use the existing alias system to resolve the target hook
136+ target_hook = resolve_alias (self , alias_name , self .hook_aliases )
137+ if target_hook is not None :
138+ self ._hook_registry [alias_name ] = target_hook
139+
140+ self ._hook_registry_initialized = True
141+
142+ def _scan_existing_hooks (self , module : nn .Module , prefix : str = "" ) -> None :
143+ """Scan existing modules for hooks and add them to registry."""
144+ visited = set ()
145+
146+ def scan_module (mod : nn .Module , path : str = "" ) -> None :
147+ obj_id = id (mod )
96148 if obj_id in visited :
97149 return
98150 visited .add (obj_id )
99151
100- for attr_name in dir (module ):
101- if attr_name .startswith ("_" ):
102- continue
103- # Skip original_component to avoid deep traversal
104- if attr_name == "original_component" :
105- continue
152+ # Check if this is a GeneralizedComponent with its own hook registry
153+ if hasattr (mod , "get_hooks" ) and callable (getattr (mod , "get_hooks" )):
154+ # Use the component's own hook registry
106155 try :
107- attr = getattr (module , attr_name )
156+ component_hooks = mod .get_hooks () # type: ignore
157+ if isinstance (component_hooks , dict ):
158+ # Type cast to help mypy understand this is a dict of hooks
159+ hooks_dict = cast (Dict [str , HookPoint ], component_hooks ) # type: ignore
160+ for hook_name , hook in hooks_dict .items (): # type: ignore
161+ full_name = f"{ path } .{ hook_name } " if path else hook_name
162+ hook .name = full_name
163+ self ._hook_registry [full_name ] = hook
108164 except Exception :
109- continue
165+ # If get_hooks() fails, fall through to the else block
166+ pass
167+ else :
168+ # Fall back to scanning attributes for non-GeneralizedComponent modules
169+ for attr_name in dir (mod ):
170+ if attr_name .startswith ("_" ):
171+ continue
172+ if attr_name == "original_component" :
173+ continue
110174
111- name = f"{ prefix } .{ attr_name } " if prefix else attr_name
112- if isinstance (attr , HookPoint ):
113- # Set the name on the HookPoint so it can be used in caching
114- attr .name = name
115- hooks [name ] = attr
116- elif hasattr (attr , "hook_in" ) and hasattr (attr , "hook_out" ):
117- # Handle HookPointWrapper objects
118- if isinstance (attr , HookPointWrapper ):
119- # Add hook_in and hook_out from the wrapper
175+ try :
176+ attr = getattr (mod , attr_name )
177+ except Exception :
178+ continue
179+
180+ name = f"{ path } .{ attr_name } " if path else attr_name
181+
182+ if isinstance (attr , HookPoint ):
183+ attr .name = name
184+ self ._hook_registry [name ] = attr
185+ elif isinstance (attr , HookPointWrapper ):
120186 hook_in_name = f"{ name } .hook_in"
121187 hook_out_name = f"{ name } .hook_out"
122188 attr .hook_in .name = hook_in_name
123189 attr .hook_out .name = hook_out_name
124- hooks [hook_in_name ] = attr .hook_in
125- hooks [hook_out_name ] = attr .hook_out
126- elif isinstance (attr , nn .Module ) and attr is not module :
127- collect_hookpoints (attr , name )
128- elif isinstance (attr , (list , tuple )):
129- for i , item in enumerate (attr ):
130- if isinstance (item , nn .Module ):
131- collect_hookpoints (item , f"{ name } [{ i } ]" )
132-
133- # Also traverse named_children() to catch ModuleList and other containers
134- for child_name , child_module in module .named_children ():
135- # Skip original_component and _original_component to avoid deep traversal
190+ self ._hook_registry [hook_in_name ] = attr .hook_in
191+ self ._hook_registry [hook_out_name ] = attr .hook_out
192+ elif isinstance (attr , nn .Module ) and attr is not mod :
193+ scan_module (attr , name )
194+ elif isinstance (attr , (list , tuple )):
195+ for i , item in enumerate (attr ):
196+ if isinstance (item , nn .Module ):
197+ scan_module (item , f"{ name } [{ i } ]" )
198+
199+ # Check named children
200+ for child_name , child_module in mod .named_children ():
136201 if child_name == "original_component" or child_name == "_original_component" :
137202 continue
138- child_path = f"{ prefix } .{ child_name } " if prefix else child_name
139- collect_hookpoints (child_module , child_path )
203+ child_path = f"{ path } .{ child_name } " if path else child_name
204+ scan_module (child_module , child_path )
205+
206+ scan_module (module , prefix )
207+
208+ @property
209+ def hook_dict (self ) -> dict [str , HookPoint ]:
210+ """Get all HookPoint objects in the model for compatibility with HookedTransformer."""
211+ # Start with the current registry
212+ hooks = self ._hook_registry .copy ()
213+
214+ # Add aliases if compatibility mode is enabled
215+ if self .compatibility_mode :
216+ for alias_name , target_name in self .hook_aliases .items ():
217+ if target_name in hooks :
218+ hooks [alias_name ] = hooks [target_name ]
140219
141- collect_hookpoints (self , "" )
142220 return hooks
143221
222+ def _discover_hooks (self ) -> dict [str , HookPoint ]:
223+ """Get all HookPoint objects from the registry (deprecated, use hook_dict)."""
224+ return self ._hook_registry .copy ()
225+
226+ def clear_hook_cache (self ) -> None :
227+ """Clear the cached hook discovery results (deprecated, kept for compatibility)."""
228+ pass # No longer needed since we don't use caching
229+
230+ def clear_hook_registry (self ) -> None :
231+ """Clear the hook registry and force re-initialization."""
232+ self ._hook_registry .clear ()
233+ self ._hook_registry_initialized = False
234+
144235 def __getattr__ (self , name : str ) -> Any :
145236 """Provide a clear error message for missing attributes."""
146237 if name in self .__dict__ :
147238 return self .__dict__ [name ]
148239
240+ # Check if this is a hook alias when compatibility mode is enabled
241+ if self .compatibility_mode and name in self .hook_aliases :
242+ target_name = self .hook_aliases [name ]
243+ if target_name in self ._hook_registry :
244+ return self ._hook_registry [target_name ]
245+
149246 return super ().__getattr__ (name )
150247
151248 def _get_nested_attr (self , path : str ) -> Any :
@@ -247,6 +344,10 @@ def set_compatibility_mode(component: Any) -> None:
247344
248345 apply_fn_to_all_components (self , set_compatibility_mode )
249346
347+ # Re-initialize the hook registry to include aliases from components
348+ self .clear_hook_registry ()
349+ self ._initialize_hook_registry ()
350+
250351 # ==================== TOKENIZATION METHODS ====================
251352
252353 def to_tokens (
@@ -738,56 +839,14 @@ def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
738839
739840 return cache_hook
740841
741- # Recursively collect all HookPoint objects
742- def collect_hookpoints (module : nn .Module , prefix : str = "" ) -> None :
743- obj_id = id (module )
744- if obj_id in visited :
745- return
746- visited .add (obj_id )
747-
748- for attr_name in dir (module ):
749- if attr_name .startswith ("_" ):
750- continue
751- # Skip the original_model to avoid collecting hooks from HuggingFace model
752- if attr_name == "original_model" or attr_name == "original_component" :
753- continue
754- try :
755- attr = getattr (module , attr_name )
756- except Exception :
757- continue
758-
759- def add_hook_to_list (hook : HookPoint , name : str ):
760- # Set the name on the HookPoint so it can be used in caching
761- hook .name = name
762-
763- # Only add hook if it passes the names filter
764- if names_filter_fn (name ):
765- hooks .append ((hook , name ))
766-
767- name = f"{ prefix } .{ attr_name } " if prefix else attr_name
768- if isinstance (attr , HookPoint ):
769- add_hook_to_list (attr , name )
770- elif isinstance (attr , HookPointWrapper ):
771- # Add hooks for the wrapped hook points (hook_in and hook_out)
772- add_hook_to_list (attr .hook_in , f"{ name } .hook_in" )
773- add_hook_to_list (attr .hook_out , f"{ name } .hook_out" )
774- elif isinstance (attr , nn .Module ):
775- collect_hookpoints (attr , name )
776- elif isinstance (attr , (list , tuple )):
777- for i , item in enumerate (attr ):
778- if isinstance (item , nn .Module ):
779- collect_hookpoints (item , f"{ name } [{ i } ]" )
780-
781- # Also traverse named_children() to catch ModuleList and other containers
782- for child_name , child_module in module .named_children ():
783- child_path = f"{ prefix } .{ child_name } " if prefix else child_name
784- # Skip the original_model module
785- if child_name == "original_model" or child_name == "original_component" :
786- continue
787- collect_hookpoints (child_module , child_path )
842+ # Use cached hooks instead of re-discovering them
843+ hook_dict = self .hook_dict
788844
789- # Collect hooks from bridge components (these have the clean TransformerLens paths)
790- collect_hookpoints (self , "" )
845+ # Filter hooks based on names_filter
846+ for hook_name , hook in hook_dict .items ():
847+ # Only add hook if it passes the names filter
848+ if names_filter_fn (hook_name ):
849+ hooks .append ((hook , hook_name ))
791850
792851 # Register hooks
793852 for hp , name in hooks :
0 commit comments