Skip to content

Commit 840dc44

Browse files
Enable setting cached hooks (#1048)
* Enable setting which hooks to cache * Scan aliases when compatibility mode is enabled * Add bridge aliases to hook_registry again * Simplifications * Add new hooks to hooks cached by default * Restore proper GPT-2 configuration * ran format * Proper access of mlp.in with getattr --------- Co-authored-by: Bryce Meyer <[email protected]>
1 parent f34f5d9 commit 840dc44

File tree

1 file changed

+142
-3
lines changed

1 file changed

+142
-3
lines changed

transformer_lens/model_bridge/bridge.py

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def __init__(
105105
# Initialize hook registry after components are set up
106106
self._initialize_hook_registry()
107107

108+
# Initialize dictionary containing hooks that will be cached
109+
self._initialize_hooks_to_cache()
110+
108111
def __setattr__(self, name: str, value: Any) -> None:
109112
"""Override setattr to track HookPoint objects dynamically."""
110113
# Call parent setattr first
@@ -141,8 +144,37 @@ def _initialize_hook_registry(self) -> None:
141144
# Scan existing components for hooks
142145
self._scan_existing_hooks(self, "")
143146

147+
# Add bridge aliases if compatibility mode is enabled
148+
if self.compatibility_mode:
149+
self._add_aliases_to_hooks(self._hook_registry)
150+
144151
self._hook_registry_initialized = True
145152

153+
def _add_aliases_to_hooks(self, hooks: Dict[str, HookPoint]) -> None:
    """Insert alias entries into ``hooks``, mutating the mapping in place.

    Each entry of ``self.hook_aliases`` maps an alias name to either a single
    target or a list of candidate targets. Every target is resolved with
    ``resolve_alias``; when the result is not ``None`` it is stored in
    ``hooks`` under the alias name.

    Args:
        hooks: Mapping of hook name to hook object that receives the aliases.
    """
    alias_map = self.hook_aliases

    # Nothing to add when no aliases are declared.
    if not alias_map:
        return

    for alias_name, target in alias_map.items():
        if not isinstance(target, list):
            # Single target: resolve directly. Errors raised by the resolver
            # are intentionally left to propagate here.
            resolved = resolve_alias(self, alias_name, {alias_name: target})
            if resolved is not None:
                hooks[alias_name] = resolved
            continue

        # List of candidates: the first one that resolves wins. resolve_alias
        # takes a one-entry mapping, so each candidate is wrapped on its own.
        for candidate in target:
            try:
                resolved = resolve_alias(self, alias_name, {alias_name: candidate})
            except AttributeError:
                # Candidate could not be resolved on this model; try the next.
                continue
            if resolved is not None:
                hooks[alias_name] = resolved
                break
177+
146178
def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
147179
"""Scan existing modules for hooks and add them to registry."""
148180
visited = set()
@@ -210,8 +242,13 @@ def scan_module(mod: nn.Module, path: str = "") -> None:
210242
@property
def hook_dict(self) -> dict[str, HookPoint]:
    """Return a fresh mapping of hook names to hooks (HookedTransformer-compatible).

    The result is a copy of the internal hook registry, so callers may mutate
    it freely. When compatibility mode is enabled, alias entries are added to
    the copy via ``_add_aliases_to_hooks``.
    """
    merged = self._hook_registry.copy()

    # Aliases are only exposed under compatibility mode.
    if self.compatibility_mode:
        self._add_aliases_to_hooks(merged)

    return merged
215252

216253
def _discover_hooks(self) -> dict[str, HookPoint]:
217254
"""Get all HookPoint objects from the registry (deprecated, use hook_dict)."""
@@ -226,6 +263,108 @@ def clear_hook_registry(self) -> None:
226263
self._hook_registry.clear()
227264
self._hook_registry_initialized = False
228265

266+
def _initialize_hooks_to_cache(self) -> None:
267+
"""Initialize the hooks to cache when running the model with cache."""
268+
self.hooks_to_cache = {}
269+
270+
default_cached_hooks_names = [
271+
"embed.hook_in",
272+
"embed.hook_out",
273+
"pos_embed.hook_in",
274+
"pos_embed.hook_out",
275+
"rotary_embed.hook_in",
276+
"rotary_embed.hook_out",
277+
"ln_final.hook_in",
278+
"ln_final.hook_scale",
279+
"ln_final.hook_normalized",
280+
"ln_final.hook_out",
281+
"unembed.hook_in",
282+
"unembed.hook_out",
283+
]
284+
285+
for block_idx in range(self.cfg.n_layers):
286+
default_cached_hooks_names.append(f"blocks.{block_idx}.hook_in")
287+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_in")
288+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_scale")
289+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_normalized")
290+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1.hook_out")
291+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_in")
292+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_scale")
293+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_normalized")
294+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln1_post.hook_out")
295+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_in")
296+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_in")
297+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q.hook_out")
298+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_in")
299+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.q_norm.hook_out")
300+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_in")
301+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k.hook_out")
302+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_in")
303+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.k_norm.hook_out")
304+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_in")
305+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.v.hook_out")
306+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_in")
307+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.o.hook_out")
308+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_attn_scores")
309+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_pattern")
310+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_hidden_states")
311+
default_cached_hooks_names.append(f"blocks.{block_idx}.attn.hook_out")
312+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_in")
313+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_scale")
314+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_normalized")
315+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2.hook_out")
316+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_in")
317+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_scale")
318+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_normalized")
319+
default_cached_hooks_names.append(f"blocks.{block_idx}.ln2_post.hook_out")
320+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_in")
321+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_in")
322+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.in.hook_out")
323+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_in")
324+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.out.hook_out")
325+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_in")
326+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.gate.hook_out")
327+
default_cached_hooks_names.append(f"blocks.{block_idx}.mlp.hook_out")
328+
default_cached_hooks_names.append(f"blocks.{block_idx}.hook_out")
329+
330+
for hook_name in default_cached_hooks_names:
331+
if hook_name in self._hook_registry:
332+
self.hooks_to_cache[hook_name] = self._hook_registry[hook_name]
333+
334+
def set_hooks_to_cache(
    self, hook_names: Optional[List[str]] = None, include_all: bool = False
) -> None:
    """Set the hooks to cache when running the model with cache.

    Names are looked up in ``self.hook_dict``, which contains every hook in
    the registry and, when compatibility mode is enabled, the alias names as
    well. To use a hook name that was only available in the old
    HookedTransformer, make sure compatibility mode is enabled.

    Args:
        hook_names (Optional[List[str]]): List of hook names to cache.
            Ignored when ``include_all`` is True.
        include_all (bool): Whether to cache all hooks.

    Raises:
        ValueError: If ``include_all`` is False and ``hook_names`` is None,
            or if any requested name is not a known hook. In both cases
            ``self.hooks_to_cache`` is left unchanged.
    """
    # hook_dict returns a fresh copy that already includes aliases under
    # compatibility mode, so fetch it once and reuse it for every lookup.
    available_hooks = self.hook_dict

    if include_all:
        self.hooks_to_cache = available_hooks
        return

    if hook_names is None:
        raise ValueError("hook_names must be provided if include_all is False")

    selected = {}
    for hook_name in hook_names:
        if hook_name not in available_hooks:
            raise ValueError(
                f"Hook {hook_name} does not exist. If you are using a hook name used with the old HookedTransformer, make sure to enable compatibility mode."
            )
        selected[hook_name] = available_hooks[hook_name]

    # Assign only after every name validated, so a failure leaves the
    # previous selection intact.
    self.hooks_to_cache = selected
367+
229368
def __getattr__(self, name: str) -> Any:
230369
"""Provide a clear error message for missing attributes."""
231370
if name in self.__dict__:
@@ -1543,7 +1682,7 @@ def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
15431682
return cache_hook
15441683

15451684
# Use cached hooks instead of re-discovering them
1546-
hook_dict = self.hook_dict
1685+
hook_dict = self.hooks_to_cache
15471686

15481687
# Filter hooks based on names_filter
15491688
for hook_name, hook in hook_dict.items():

0 commit comments

Comments
 (0)