Skip to content

Commit 3355912

Browse files
made sure to check for nested hooks (#1035)
* made sure to check for nested hooks * removed extra check * Skip original_model in hook scanning * Remove extra traversal through general components submodules * Remove adding aliases to hook registry * Fix typing error * Fix typing error * Remove constant copying of hook dictionary to save memory --------- Co-authored-by: degenfabian <[email protected]>
1 parent 7c8d9bb commit 3355912

File tree

3 files changed

+48
-96
lines changed

3 files changed

+48
-96
lines changed

transformer_lens/model_bridge/bridge.py

Lines changed: 39 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -131,28 +131,6 @@ def _initialize_hook_registry(self) -> None:
131131
# Scan existing components for hooks
132132
self._scan_existing_hooks(self, "")
133133

134-
# Add bridge aliases if compatibility mode is enabled
135-
if self.compatibility_mode and self.hook_aliases:
136-
for alias_name, target in self.hook_aliases.items():
137-
# Use the existing alias system to resolve the target hook
138-
# Convert to Dict[str, str] for resolve_alias if target_name is a list
139-
if isinstance(target, list):
140-
# For list targets, try each one until one works
141-
for single_target in target:
142-
try:
143-
target_hook = resolve_alias(
144-
self, alias_name, {alias_name: single_target}
145-
)
146-
if target_hook is not None:
147-
self._hook_registry[alias_name] = target_hook
148-
break
149-
except AttributeError:
150-
continue
151-
else:
152-
target_hook = resolve_alias(self, alias_name, {alias_name: target})
153-
if target_hook is not None:
154-
self._hook_registry[alias_name] = target_hook
155-
156134
self._hook_registry_initialized = True
157135

158136
def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
@@ -180,41 +158,39 @@ def scan_module(mod: nn.Module, path: str = "") -> None:
180158
except Exception:
181159
# If get_hooks() fails, fall through to the else block
182160
pass
183-
else:
184-
# Fall back to scanning attributes for non-GeneralizedComponent modules
185-
for attr_name in dir(mod):
186-
if attr_name.startswith("_"):
187-
continue
188-
if attr_name == "original_component":
189-
continue
190161

191-
try:
192-
attr = getattr(mod, attr_name)
193-
except Exception:
194-
continue
195-
196-
name = f"{path}.{attr_name}" if path else attr_name
197-
198-
if isinstance(attr, HookPoint):
199-
attr.name = name
200-
self._hook_registry[name] = attr
201-
elif isinstance(attr, HookPointWrapper):
202-
hook_in_name = f"{name}.hook_in"
203-
hook_out_name = f"{name}.hook_out"
204-
attr.hook_in.name = hook_in_name
205-
attr.hook_out.name = hook_out_name
206-
self._hook_registry[hook_in_name] = attr.hook_in
207-
self._hook_registry[hook_out_name] = attr.hook_out
208-
elif isinstance(attr, nn.Module) and attr is not mod:
209-
scan_module(attr, name)
210-
elif isinstance(attr, (list, tuple)):
211-
for i, item in enumerate(attr):
212-
if isinstance(item, nn.Module):
213-
scan_module(item, f"{name}[{i}]")
162+
# Always scan attributes for additional hooks and submodules
163+
for attr_name in dir(mod):
164+
if attr_name.startswith("_"):
165+
continue
166+
if attr_name == "original_component" or "original_model":
167+
continue
168+
169+
try:
170+
attr = getattr(mod, attr_name)
171+
except Exception:
172+
continue
173+
174+
name = f"{path}.{attr_name}" if path else attr_name
175+
176+
if isinstance(attr, HookPoint):
177+
attr.name = name
178+
self._hook_registry[name] = attr
179+
elif isinstance(attr, HookPointWrapper):
180+
hook_in_name = f"{name}.hook_in"
181+
hook_out_name = f"{name}.hook_out"
182+
attr.hook_in.name = hook_in_name
183+
attr.hook_out.name = hook_out_name
184+
self._hook_registry[hook_in_name] = attr.hook_in
185+
self._hook_registry[hook_out_name] = attr.hook_out
214186

215187
# Check named children
216188
for child_name, child_module in mod.named_children():
217-
if child_name == "original_component" or child_name == "_original_component":
189+
if (
190+
child_name == "original_component"
191+
or child_name == "_original_component"
192+
or child_name == "original_model"
193+
):
218194
continue
219195
child_path = f"{path}.{child_name}" if path else child_name
220196
scan_module(child_module, child_path)
@@ -225,23 +201,7 @@ def scan_module(mod: nn.Module, path: str = "") -> None:
225201
def hook_dict(self) -> dict[str, HookPoint]:
226202
"""Get all HookPoint objects in the model for compatibility with HookedTransformer."""
227203
# Start with the current registry
228-
hooks = self._hook_registry.copy()
229-
230-
# Add aliases if compatibility mode is enabled
231-
if self.compatibility_mode:
232-
for alias_name, target in self.hook_aliases.items():
233-
# Handle both string and list target names
234-
if isinstance(target, list):
235-
# For list targets, find the first one that exists in hooks
236-
for single_target in target:
237-
if single_target in hooks:
238-
hooks[alias_name] = hooks[single_target]
239-
break
240-
else:
241-
if target in hooks:
242-
hooks[alias_name] = hooks[target]
243-
244-
return hooks
204+
return self._hook_registry.copy()
245205

246206
def _discover_hooks(self) -> dict[str, HookPoint]:
247207
"""Get all HookPoint objects from the registry (deprecated, use hook_dict)."""
@@ -262,17 +222,10 @@ def __getattr__(self, name: str) -> Any:
262222
return self.__dict__[name]
263223

264224
# Check if this is a hook alias when compatibility mode is enabled
265-
if self.compatibility_mode and name in self.hook_aliases:
266-
target = self.hook_aliases[name]
267-
# Handle both string and list target names
268-
if isinstance(target, list):
269-
# For list targets, find the first one that exists in the registry
270-
for single_target in target:
271-
if single_target in self._hook_registry:
272-
return self._hook_registry[single_target]
273-
else:
274-
if target in self._hook_registry:
275-
return self._hook_registry[target]
225+
if self.compatibility_mode:
226+
resolved_hook = resolve_alias(self, name, self.hook_aliases)
227+
if resolved_hook is not None:
228+
return resolved_hook
276229

277230
return super().__getattr__(name)
278231

@@ -982,7 +935,7 @@ def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
982935
return cache_hook
983936

984937
# Use cached hooks instead of re-discovering them
985-
hook_dict = self.hook_dict
938+
hook_dict = self._hook_registry
986939

987940
# Filter hooks based on names_filter
988941
for hook_name, hook in hook_dict.items():
@@ -1026,7 +979,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
1026979

1027980
# Add hook to the output of the last layer to be processed
1028981
block_hook_name = f"blocks.{last_layer_to_process}.hook_out"
1029-
hook_dict = self.hook_dict
982+
hook_dict = self._hook_registry
1030983
if block_hook_name in hook_dict:
1031984
hook_dict[block_hook_name].add_hook(stop_hook)
1032985
hooks.append((hook_dict[block_hook_name], block_hook_name))
@@ -1167,7 +1120,7 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
11671120

11681121
# Add hook to the output of the last layer to be processed
11691122
block_hook_name = f"blocks.{last_layer_to_process}.hook_out"
1170-
hook_dict = self.hook_dict
1123+
hook_dict = self._hook_registry
11711124
if block_hook_name in hook_dict:
11721125
add_hook_to_point(hook_dict[block_hook_name], stop_hook, block_hook_name)
11731126

@@ -1197,7 +1150,7 @@ def wrapped_hook_fn(tensor, hook):
11971150

11981151
if isinstance(hook_name_or_filter, str):
11991152
# Direct hook name - check for aliases first
1200-
hook_dict = self.hook_dict
1153+
hook_dict = self._hook_registry
12011154
actual_hook_name = hook_name_or_filter
12021155

12031156
# If this is an alias, resolve it to the actual hook name
@@ -1208,7 +1161,7 @@ def wrapped_hook_fn(tensor, hook):
12081161
add_hook_to_point(hook_dict[actual_hook_name], hook_fn, actual_hook_name)
12091162
else:
12101163
# Filter function
1211-
hook_dict = self.hook_dict
1164+
hook_dict = self._hook_registry
12121165
for name, hook_point in hook_dict.items():
12131166
if hook_name_or_filter(name):
12141167
add_hook_to_point(hook_point, hook_fn, name)

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,6 @@ def __init__(
6969
self.hook_in.hook_conversion = self.conversion_rule
7070
self.hook_out.hook_conversion = self.conversion_rule
7171

72-
# Register the standard hooks
73-
self._register_hook("hook_in", self.hook_in)
74-
self._register_hook("hook_out", self.hook_out)
75-
7672
def _register_hook(self, name: str, hook: HookPoint) -> None:
7773
"""Register a hook in the component's hook registry."""
7874
# Set the name on the HookPoint
@@ -82,17 +78,20 @@ def _register_hook(self, name: str, hook: HookPoint) -> None:
8278

8379
def get_hooks(self) -> Dict[str, HookPoint]:
8480
"""Get all hooks registered in this component."""
85-
hooks = self._hook_registry.copy()
8681

8782
# Add aliases if compatibility mode is enabled
8883
if self.compatibility_mode and self.hook_aliases:
84+
# Only copy hook registry if compatibility mode is enabled to save memory
85+
hooks = self._hook_registry.copy()
86+
8987
for alias_name, target_name in self.hook_aliases.items():
9088
# Use the existing alias system to resolve the target hook
9189
target_hook = resolve_alias(self, alias_name, self.hook_aliases)
9290
if target_hook is not None:
9391
hooks[alias_name] = target_hook
94-
95-
return hooks
92+
return hooks
93+
else:
94+
return self._hook_registry
9695

9796
def _is_getattr_called_internally(self) -> bool:
9897
"""This function checks if the __getattr__ method was being called internally

transformer_lens/utilities/aliases.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Utilities for handling hook aliases in the bridge system."""
22

33
import warnings
4-
from typing import Any, Dict, Optional, Set
4+
from typing import Any, Dict, List, Optional, Set, Union
55

66

77
def resolve_alias(
88
target_object: Any,
99
requested_name: str,
10-
aliases: Dict[str, str],
10+
aliases: Dict[str, str] | Dict[str, Union[str, List[str]]],
1111
) -> Optional[Any]:
1212
"""Resolve a hook alias to the actual hook object.
1313
@@ -76,7 +76,7 @@ def _resolve_single_target(target_name: str) -> Any:
7676

7777

7878
def _collect_aliases_from_module(
79-
module: Any, path: str, aliases: Dict[str, str], visited: Set[int]
79+
module: Any, path: str, aliases: Dict[str, str], visited: Set[int] = set()
8080
) -> None:
8181
"""Helper function to collect all aliases from a single module.
8282
Args:

0 commit comments

Comments
 (0)