Skip to content

Commit 5dd54d9

Browse files
authored
added cache layer for hook collection (#1032)
* added cache layer for hook collection * added hook registry * merged setattr * fixed type issue * made sure aliases were used during hook registration from generalized components * resolved aliased hooks proplery * resolved remaining hook alias issues
1 parent 4bed46d commit 5dd54d9

File tree

5 files changed

+189
-163
lines changed

5 files changed

+189
-163
lines changed

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ def test_cache():
110110
model_name = "gpt2" # Use a smaller model for testing
111111
bridge = TransformerBridge.boot_transformers(model_name)
112112

113+
# Enable compatibility mode to include hook aliases
114+
bridge.enable_compatibility_mode(disable_warnings=True)
115+
113116
if bridge.tokenizer.pad_token is None:
114117
bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
115118

tests/integration/model_bridge/test_qkv_hook_compatibility.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -120,72 +120,3 @@ def test_hook_aliases_work_correctly(self):
120120
assert qkv_bridge.q_hook_out is bridge.blocks[0].attn.q.hook_out, "Q property should work"
121121
assert qkv_bridge.k_hook_out is bridge.blocks[0].attn.k.hook_out, "K property should work"
122122
assert qkv_bridge.v_hook_out is bridge.blocks[0].attn.v.hook_out, "V property should work"
123-
124-
def test_head_ablation_hook_works_correctly(self):
125-
"""Test that head ablation hook works correctly with TransformerBridge."""
126-
# Load GPT-2 in TransformerBridge
127-
bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
128-
129-
# Turn on compatibility mode
130-
bridge.enable_compatibility_mode(disable_warnings=True)
131-
132-
# Create test tokens (same as in the demo)
133-
gpt2_tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
134-
135-
layer_to_ablate = 0
136-
head_index_to_ablate = 8
137-
138-
# Test both hook names
139-
hook_names_to_test = [
140-
"blocks.0.attn.hook_v", # Compatibility mode alias
141-
"blocks.0.attn.v.hook_out", # Direct property access
142-
]
143-
144-
for hook_name in hook_names_to_test:
145-
print(f"\nTesting hook name: {hook_name}")
146-
147-
# Track if the hook was called
148-
hook_called = False
149-
mutation_applied = False
150-
151-
# We define a head ablation hook
152-
def head_ablation_hook(value, hook):
153-
nonlocal hook_called, mutation_applied
154-
hook_called = True
155-
print(f"Shape of the value tensor: {value.shape}")
156-
157-
# Apply the ablation (out-of-place to avoid view modification error)
158-
result = value.clone()
159-
result[:, :, head_index_to_ablate, :] = 0.0
160-
161-
# Check if the mutation was applied (the result should be zero for the ablated head)
162-
if torch.all(result[:, :, head_index_to_ablate, :] == 0.0):
163-
mutation_applied = True
164-
165-
return result
166-
167-
# Get original loss
168-
original_loss = bridge(gpt2_tokens, return_type="loss")
169-
170-
# Run with head ablation hook
171-
ablated_loss = bridge.run_with_hooks(
172-
gpt2_tokens, return_type="loss", fwd_hooks=[(hook_name, head_ablation_hook)]
173-
)
174-
175-
print(f"Original Loss: {original_loss.item():.3f}")
176-
print(f"Ablated Loss: {ablated_loss.item():.3f}")
177-
178-
# Assert that the hook was called
179-
assert hook_called, f"Head ablation hook should have been called for {hook_name}"
180-
181-
# Assert that the mutation was applied
182-
assert (
183-
mutation_applied
184-
), f"Mutation should have been applied to the value tensor for {hook_name}"
185-
186-
# Assert that ablated loss is higher than original loss (ablation should hurt performance)
187-
assert (
188-
ablated_loss.item() > original_loss.item()
189-
), f"Ablated loss should be higher than original loss for {hook_name}"
190-
191-
print(f"✅ Hook {hook_name} works correctly!")

transformer_lens/model_bridge/bridge.py

Lines changed: 150 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636
from transformer_lens.model_bridge.hook_point_wrapper import HookPointWrapper
3737
from transformer_lens.model_bridge.types import ComponentMapping
38-
from transformer_lens.utilities.aliases import collect_aliases_recursive
38+
from transformer_lens.utilities.aliases import collect_aliases_recursive, resolve_alias
3939

4040
if TYPE_CHECKING:
4141
from transformer_lens.ActivationCache import ActivationCache
@@ -54,6 +54,7 @@ class TransformerBridge(nn.Module):
5454
hook_aliases = {
5555
"hook_embed": "embed.hook_out",
5656
"hook_pos_embed": "pos_embed.hook_out",
57+
"hook_unembed": "unembed.hook_out",
5758
}
5859

5960
def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: Any):
@@ -70,6 +71,11 @@ def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: An
7071
self.cfg = adapter.cfg
7172
self.tokenizer = tokenizer
7273
self.compatibility_mode = False
74+
self._hook_cache = None # Cache for hook discovery results
75+
self._hook_registry: Dict[
76+
str, HookPoint
77+
] = {} # Dynamic registry of hook names to HookPoints
78+
self._hook_registry_initialized = False # Track if registry has been initialized
7379

7480
# Add device information to config from the loaded model
7581
if not hasattr(self.cfg, "device"):
@@ -84,68 +90,159 @@ def __init__(self, model: nn.Module, adapter: ArchitectureAdapter, tokenizer: An
8490
# Set original components on the pre-created bridge components
8591
set_original_components(self, self.adapter, self.original_model)
8692

87-
@property
88-
def hook_dict(self) -> dict[str, HookPoint]:
89-
"""Get all HookPoint objects in the model for compatibility with HookedTransformer."""
90-
hooks = {}
91-
visited = set() # Move visited set outside the recursive function
92-
93-
def collect_hookpoints(module: nn.Module, prefix: str = "") -> None:
94-
"""Recursively collect all HookPoint objects."""
95-
obj_id = id(module)
93+
# Initialize hook registry after components are set up
94+
self._initialize_hook_registry()
95+
96+
def __setattr__(self, name: str, value: Any) -> None:
97+
"""Override setattr to track HookPoint objects dynamically."""
98+
# Call parent setattr first
99+
super().__setattr__(name, value)
100+
101+
# Check if this is a HookPoint being set
102+
if isinstance(value, HookPoint):
103+
# Set the name on the HookPoint
104+
value.name = name
105+
# Add to registry
106+
self._hook_registry[name] = value
107+
elif isinstance(value, HookPointWrapper):
108+
# Handle HookPointWrapper objects
109+
hook_in_name = f"{name}.hook_in"
110+
hook_out_name = f"{name}.hook_out"
111+
value.hook_in.name = hook_in_name
112+
value.hook_out.name = hook_out_name
113+
self._hook_registry[hook_in_name] = value.hook_in
114+
self._hook_registry[hook_out_name] = value.hook_out
115+
elif hasattr(value, "get_hooks") and callable(getattr(value, "get_hooks")):
116+
# This is a GeneralizedComponent being set
117+
# We need to register its hooks with the appropriate prefix
118+
component_hooks = value.get_hooks()
119+
for hook_name, hook in component_hooks.items():
120+
full_name = f"{name}.{hook_name}"
121+
hook.name = full_name
122+
self._hook_registry[full_name] = hook
123+
124+
def _initialize_hook_registry(self) -> None:
125+
"""Initialize the hook registry by scanning existing components."""
126+
if self._hook_registry_initialized:
127+
return
128+
129+
# Scan existing components for hooks
130+
self._scan_existing_hooks(self, "")
131+
132+
# Add bridge aliases if compatibility mode is enabled
133+
if self.compatibility_mode and self.hook_aliases:
134+
for alias_name, target_name in self.hook_aliases.items():
135+
# Use the existing alias system to resolve the target hook
136+
target_hook = resolve_alias(self, alias_name, self.hook_aliases)
137+
if target_hook is not None:
138+
self._hook_registry[alias_name] = target_hook
139+
140+
self._hook_registry_initialized = True
141+
142+
def _scan_existing_hooks(self, module: nn.Module, prefix: str = "") -> None:
143+
"""Scan existing modules for hooks and add them to registry."""
144+
visited = set()
145+
146+
def scan_module(mod: nn.Module, path: str = "") -> None:
147+
obj_id = id(mod)
96148
if obj_id in visited:
97149
return
98150
visited.add(obj_id)
99151

100-
for attr_name in dir(module):
101-
if attr_name.startswith("_"):
102-
continue
103-
# Skip original_component to avoid deep traversal
104-
if attr_name == "original_component":
105-
continue
152+
# Check if this is a GeneralizedComponent with its own hook registry
153+
if hasattr(mod, "get_hooks") and callable(getattr(mod, "get_hooks")):
154+
# Use the component's own hook registry
106155
try:
107-
attr = getattr(module, attr_name)
156+
component_hooks = mod.get_hooks() # type: ignore
157+
if isinstance(component_hooks, dict):
158+
# Type cast to help mypy understand this is a dict of hooks
159+
hooks_dict = cast(Dict[str, HookPoint], component_hooks) # type: ignore
160+
for hook_name, hook in hooks_dict.items(): # type: ignore
161+
full_name = f"{path}.{hook_name}" if path else hook_name
162+
hook.name = full_name
163+
self._hook_registry[full_name] = hook
108164
except Exception:
109-
continue
165+
# If get_hooks() fails, fall through to the else block
166+
pass
167+
else:
168+
# Fall back to scanning attributes for non-GeneralizedComponent modules
169+
for attr_name in dir(mod):
170+
if attr_name.startswith("_"):
171+
continue
172+
if attr_name == "original_component":
173+
continue
110174

111-
name = f"{prefix}.{attr_name}" if prefix else attr_name
112-
if isinstance(attr, HookPoint):
113-
# Set the name on the HookPoint so it can be used in caching
114-
attr.name = name
115-
hooks[name] = attr
116-
elif hasattr(attr, "hook_in") and hasattr(attr, "hook_out"):
117-
# Handle HookPointWrapper objects
118-
if isinstance(attr, HookPointWrapper):
119-
# Add hook_in and hook_out from the wrapper
175+
try:
176+
attr = getattr(mod, attr_name)
177+
except Exception:
178+
continue
179+
180+
name = f"{path}.{attr_name}" if path else attr_name
181+
182+
if isinstance(attr, HookPoint):
183+
attr.name = name
184+
self._hook_registry[name] = attr
185+
elif isinstance(attr, HookPointWrapper):
120186
hook_in_name = f"{name}.hook_in"
121187
hook_out_name = f"{name}.hook_out"
122188
attr.hook_in.name = hook_in_name
123189
attr.hook_out.name = hook_out_name
124-
hooks[hook_in_name] = attr.hook_in
125-
hooks[hook_out_name] = attr.hook_out
126-
elif isinstance(attr, nn.Module) and attr is not module:
127-
collect_hookpoints(attr, name)
128-
elif isinstance(attr, (list, tuple)):
129-
for i, item in enumerate(attr):
130-
if isinstance(item, nn.Module):
131-
collect_hookpoints(item, f"{name}[{i}]")
132-
133-
# Also traverse named_children() to catch ModuleList and other containers
134-
for child_name, child_module in module.named_children():
135-
# Skip original_component and _original_component to avoid deep traversal
190+
self._hook_registry[hook_in_name] = attr.hook_in
191+
self._hook_registry[hook_out_name] = attr.hook_out
192+
elif isinstance(attr, nn.Module) and attr is not mod:
193+
scan_module(attr, name)
194+
elif isinstance(attr, (list, tuple)):
195+
for i, item in enumerate(attr):
196+
if isinstance(item, nn.Module):
197+
scan_module(item, f"{name}[{i}]")
198+
199+
# Check named children
200+
for child_name, child_module in mod.named_children():
136201
if child_name == "original_component" or child_name == "_original_component":
137202
continue
138-
child_path = f"{prefix}.{child_name}" if prefix else child_name
139-
collect_hookpoints(child_module, child_path)
203+
child_path = f"{path}.{child_name}" if path else child_name
204+
scan_module(child_module, child_path)
205+
206+
scan_module(module, prefix)
207+
208+
@property
209+
def hook_dict(self) -> dict[str, HookPoint]:
210+
"""Get all HookPoint objects in the model for compatibility with HookedTransformer."""
211+
# Start with the current registry
212+
hooks = self._hook_registry.copy()
213+
214+
# Add aliases if compatibility mode is enabled
215+
if self.compatibility_mode:
216+
for alias_name, target_name in self.hook_aliases.items():
217+
if target_name in hooks:
218+
hooks[alias_name] = hooks[target_name]
140219

141-
collect_hookpoints(self, "")
142220
return hooks
143221

222+
def _discover_hooks(self) -> dict[str, HookPoint]:
223+
"""Get all HookPoint objects from the registry (deprecated, use hook_dict)."""
224+
return self._hook_registry.copy()
225+
226+
def clear_hook_cache(self) -> None:
227+
"""Clear the cached hook discovery results (deprecated, kept for compatibility)."""
228+
pass # No longer needed since we don't use caching
229+
230+
def clear_hook_registry(self) -> None:
231+
"""Clear the hook registry and force re-initialization."""
232+
self._hook_registry.clear()
233+
self._hook_registry_initialized = False
234+
144235
def __getattr__(self, name: str) -> Any:
145236
"""Provide a clear error message for missing attributes."""
146237
if name in self.__dict__:
147238
return self.__dict__[name]
148239

240+
# Check if this is a hook alias when compatibility mode is enabled
241+
if self.compatibility_mode and name in self.hook_aliases:
242+
target_name = self.hook_aliases[name]
243+
if target_name in self._hook_registry:
244+
return self._hook_registry[target_name]
245+
149246
return super().__getattr__(name)
150247

151248
def _get_nested_attr(self, path: str) -> Any:
@@ -247,6 +344,10 @@ def set_compatibility_mode(component: Any) -> None:
247344

248345
apply_fn_to_all_components(self, set_compatibility_mode)
249346

347+
# Re-initialize the hook registry to include aliases from components
348+
self.clear_hook_registry()
349+
self._initialize_hook_registry()
350+
250351
# ==================== TOKENIZATION METHODS ====================
251352

252353
def to_tokens(
@@ -738,56 +839,14 @@ def cache_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
738839

739840
return cache_hook
740841

741-
# Recursively collect all HookPoint objects
742-
def collect_hookpoints(module: nn.Module, prefix: str = "") -> None:
743-
obj_id = id(module)
744-
if obj_id in visited:
745-
return
746-
visited.add(obj_id)
747-
748-
for attr_name in dir(module):
749-
if attr_name.startswith("_"):
750-
continue
751-
# Skip the original_model to avoid collecting hooks from HuggingFace model
752-
if attr_name == "original_model" or attr_name == "original_component":
753-
continue
754-
try:
755-
attr = getattr(module, attr_name)
756-
except Exception:
757-
continue
758-
759-
def add_hook_to_list(hook: HookPoint, name: str):
760-
# Set the name on the HookPoint so it can be used in caching
761-
hook.name = name
762-
763-
# Only add hook if it passes the names filter
764-
if names_filter_fn(name):
765-
hooks.append((hook, name))
766-
767-
name = f"{prefix}.{attr_name}" if prefix else attr_name
768-
if isinstance(attr, HookPoint):
769-
add_hook_to_list(attr, name)
770-
elif isinstance(attr, HookPointWrapper):
771-
# Add hooks for the wrapped hook points (hook_in and hook_out)
772-
add_hook_to_list(attr.hook_in, f"{name}.hook_in")
773-
add_hook_to_list(attr.hook_out, f"{name}.hook_out")
774-
elif isinstance(attr, nn.Module):
775-
collect_hookpoints(attr, name)
776-
elif isinstance(attr, (list, tuple)):
777-
for i, item in enumerate(attr):
778-
if isinstance(item, nn.Module):
779-
collect_hookpoints(item, f"{name}[{i}]")
780-
781-
# Also traverse named_children() to catch ModuleList and other containers
782-
for child_name, child_module in module.named_children():
783-
child_path = f"{prefix}.{child_name}" if prefix else child_name
784-
# Skip the original_model module
785-
if child_name == "original_model" or child_name == "original_component":
786-
continue
787-
collect_hookpoints(child_module, child_path)
842+
# Use cached hooks instead of re-discovering them
843+
hook_dict = self.hook_dict
788844

789-
# Collect hooks from bridge components (these have the clean TransformerLens paths)
790-
collect_hookpoints(self, "")
845+
# Filter hooks based on names_filter
846+
for hook_name, hook in hook_dict.items():
847+
# Only add hook if it passes the names filter
848+
if names_filter_fn(hook_name):
849+
hooks.append((hook, hook_name))
791850

792851
# Register hooks
793852
for hp, name in hooks:

0 commit comments

Comments
 (0)