@@ -105,6 +105,9 @@ def __init__(
105105 # Initialize hook registry after components are set up
106106 self ._initialize_hook_registry ()
107107
108+ # Initialize the dictionary of hooks that will be cached
109+ self ._initialize_hooks_to_cache ()
110+
108111 def __setattr__ (self , name : str , value : Any ) -> None :
109112 """Override setattr to track HookPoint objects dynamically."""
110113 # Call parent setattr first
@@ -141,8 +144,37 @@ def _initialize_hook_registry(self) -> None:
141144 # Scan existing components for hooks
142145 self ._scan_existing_hooks (self , "" )
143146
147+ # Add bridge aliases if compatibility mode is enabled
148+ if self .compatibility_mode :
149+ self ._add_aliases_to_hooks (self ._hook_registry )
150+
144151 self ._hook_registry_initialized = True
145152
def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None:
    """Insert an entry into ``hooks`` for every resolvable hook alias.

    Each ``(alias_name, target)`` pair in ``self.hook_aliases`` is passed to
    ``resolve_alias``; whenever that call returns something other than
    ``None`` the result is stored in ``hooks`` under ``alias_name``.

    Args:
        hooks: Mapping to extend. It is mutated in place; nothing is returned.
    """
    alias_map = self.hook_aliases

    # Nothing configured -> leave ``hooks`` untouched.
    if not alias_map:
        return

    for alias_name, target in alias_map.items():
        if not isinstance(target, list):
            # Single target: resolve it directly. An AttributeError raised
            # by resolve_alias is NOT caught on this path.
            resolved = resolve_alias(self, alias_name, {alias_name: target})
            if resolved is not None:
                hooks[alias_name] = resolved
            continue

        # List of candidate targets: take the first one that resolves,
        # skipping candidates for which resolve_alias raises AttributeError.
        for candidate in target:
            try:
                resolved = resolve_alias(self, alias_name, {alias_name: candidate})
            except AttributeError:
                continue
            if resolved is not None:
                hooks[alias_name] = resolved
                break
177+
146178 def _scan_existing_hooks (self , module : nn .Module , prefix : str = "" ) -> None :
147179 """Scan existing modules for hooks and add them to registry."""
148180 visited = set ()
@@ -210,8 +242,13 @@ def scan_module(mod: nn.Module, path: str = "") -> None:
@property
def hook_dict(self) -> dict[str, HookPoint]:
    """Return a copy of the hook registry, for compatibility with HookedTransformer.

    The returned mapping is a fresh copy of ``self._hook_registry``, so
    callers may mutate it without affecting the registry. When
    ``self.compatibility_mode`` is truthy, alias entries are added to the
    copy via ``_add_aliases_to_hooks`` before it is returned.
    """
    snapshot = self._hook_registry.copy()

    # Without compatibility mode the plain registry copy is the answer.
    if not self.compatibility_mode:
        return snapshot

    self._add_aliases_to_hooks(snapshot)
    return snapshot
215252
216253 def _discover_hooks (self ) -> dict [str , HookPoint ]:
217254 """Get all HookPoint objects from the registry (deprecated, use hook_dict)."""
@@ -226,6 +263,108 @@ def clear_hook_registry(self) -> None:
226263 self ._hook_registry .clear ()
227264 self ._hook_registry_initialized = False
228265
def _initialize_hooks_to_cache(self) -> None:
    """Populate ``self.hooks_to_cache`` with the default set of cached hooks.

    A fixed list of default hook names is built: a handful of model-level
    names followed by, for every block index in
    ``range(self.cfg.n_layers)``, one name per entry of a fixed suffix
    table. Only names that are present in ``self._hook_registry`` are kept;
    names absent from the registry are silently skipped.

    The resulting ``self.hooks_to_cache`` maps hook name to the registry
    entry, in default-list order.
    """
    self.hooks_to_cache = {}

    # Model-level hooks (outside the per-layer blocks).
    model_level_hook_names = (
        "embed.hook_in",
        "embed.hook_out",
        "pos_embed.hook_in",
        "pos_embed.hook_out",
        "rotary_embed.hook_in",
        "rotary_embed.hook_out",
        "ln_final.hook_in",
        "ln_final.hook_scale",
        "ln_final.hook_normalized",
        "ln_final.hook_out",
        "unembed.hook_in",
        "unembed.hook_out",
    )

    # Suffixes appended to "blocks.<idx>." for every layer. Order matters:
    # it determines the insertion order of ``hooks_to_cache``.
    per_block_hook_suffixes = (
        "hook_in",
        "ln1.hook_in",
        "ln1.hook_scale",
        "ln1.hook_normalized",
        "ln1.hook_out",
        "ln1_post.hook_in",
        "ln1_post.hook_scale",
        "ln1_post.hook_normalized",
        "ln1_post.hook_out",
        "attn.hook_in",
        "attn.q.hook_in",
        "attn.q.hook_out",
        "attn.q_norm.hook_in",
        "attn.q_norm.hook_out",
        "attn.k.hook_in",
        "attn.k.hook_out",
        "attn.k_norm.hook_in",
        "attn.k_norm.hook_out",
        "attn.v.hook_in",
        "attn.v.hook_out",
        "attn.o.hook_in",
        "attn.o.hook_out",
        "attn.hook_attn_scores",
        "attn.hook_pattern",
        "attn.hook_hidden_states",
        "attn.hook_out",
        "ln2.hook_in",
        "ln2.hook_scale",
        "ln2.hook_normalized",
        "ln2.hook_out",
        "ln2_post.hook_in",
        "ln2_post.hook_scale",
        "ln2_post.hook_normalized",
        "ln2_post.hook_out",
        "mlp.hook_in",
        "mlp.in.hook_in",
        "mlp.in.hook_out",
        "mlp.out.hook_in",
        "mlp.out.hook_out",
        "mlp.gate.hook_in",
        "mlp.gate.hook_out",
        "mlp.hook_out",
        "hook_out",
    )

    default_cached_hooks_names = list(model_level_hook_names)
    default_cached_hooks_names.extend(
        f"blocks.{block_idx}.{suffix}"
        for block_idx in range(self.cfg.n_layers)
        for suffix in per_block_hook_suffixes
    )

    # Keep only the defaults that actually exist in the registry.
    registry = self._hook_registry
    self.hooks_to_cache.update(
        (hook_name, registry[hook_name])
        for hook_name in default_cached_hooks_names
        if hook_name in registry
    )
333+
def set_hooks_to_cache(
    self, hook_names: Optional[List[str]] = None, include_all: bool = False
) -> None:
    """Set the hooks to cache when running the model with cache.

    You can specify hook names that were only available in the old HookedTransformer,
    but in this case you need to make sure to enable compatibility mode.

    Names are looked up in ``self.hook_dict`` rather than directly in the
    hook registry: ``hook_dict`` returns a copy of the registry and, when
    compatibility mode is enabled, adds the alias entries to it on access.

    Args:
        hook_names (Optional[List[str]]): List of hook names to cache.
            Required unless ``include_all`` is True. An empty list results
            in no hooks being cached.
        include_all (bool): Whether to cache all hooks. When True,
            ``hook_names`` is ignored.

    Raises:
        ValueError: If ``include_all`` is False and ``hook_names`` is None,
            or if any requested name is not a known hook. In the latter
            case ``self.hooks_to_cache`` is left unchanged.
    """
    if include_all:
        self.hooks_to_cache = self.hook_dict
        return

    if hook_names is None:
        raise ValueError("hook_names must be provided if include_all is False")

    # Fetch once: every access to hook_dict builds a fresh copy.
    available_hooks = self.hook_dict

    # Build into a local dict so a bad name never leaves a partial update.
    hooks_to_cache = {}
    for hook_name in hook_names:
        if hook_name not in available_hooks:
            raise ValueError(
                f"Hook {hook_name} does not exist. If you are using a hook name used with the old HookedTransformer, make sure to enable compatibility mode."
            )
        hooks_to_cache[hook_name] = available_hooks[hook_name]

    self.hooks_to_cache = hooks_to_cache
367+
229368 def __getattr__ (self , name : str ) -> Any :
230369 """Provide a clear error message for missing attributes."""
231370 if name in self .__dict__ :
@@ -1543,7 +1682,7 @@ def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
15431682 return cache_hook
15441683
15451684 # Use cached hooks instead of re-discovering them
1546- hook_dict = self .hook_dict
1685+ hook_dict = self .hooks_to_cache
15471686
15481687 # Filter hooks based on names_filter
15491688 for hook_name , hook in hook_dict .items ():
0 commit comments