Skip to content

Commit 5464167

Browse files
authored
Added a test and ensured backward hooks are working (#1058)
* Added a test and ensured backward hooks are working. * Fixed a type issue.
1 parent 05ad6a7 commit 5464167

File tree

2 files changed

+74
-7
lines changed

2 files changed

+74
-7
lines changed

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
import torch
1212

13+
from transformer_lens import HookedTransformer
1314
from transformer_lens.ActivationCache import ActivationCache
1415
from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
1516
RearrangeHookConversion,
@@ -673,5 +674,58 @@ def test_get_params_multi_query_attention_reshaping():
673674
original_attn.v.weight.data = original_v_weight
674675

675676

677+
def test_TransformerBridge_hooks_backward_hooks():
    """Check that TransformerBridge.hooks() registers backward hooks.

    Runs the same bwd_hooks scenario against both HookedTransformer and
    TransformerBridge (in compatibility mode) and asserts that each model
    registers the hook on ``blocks.0.hook_mlp_out`` and actually invokes it
    during the backward pass.
    """
    # Build the two models from the same checkpoint so hook names line up.
    hooked_model = HookedTransformer.from_pretrained_no_processing("gpt2", device_map="cpu")
    bridge_model: TransformerBridge = TransformerBridge.boot_transformers("gpt2", device="cpu")  # type: ignore
    bridge_model.enable_compatibility_mode(no_processing=True)

    # Flags flipped by the hooks so we can verify they fired.
    hook_called = {"hooked": False, "bridge": False}

    def make_test_hook(model_type):
        # Returned hook records the call and leaves the gradient untouched.
        def hook_fn(grad, hook=None):
            hook_called[model_type] = True
            return None

        return hook_fn

    tokens = torch.tensor([[1, 2, 3]])

    # HookedTransformer: register a backward hook, run forward, confirm the
    # hook point holds it, then trigger backward so it fires.
    with hooked_model.hooks(bwd_hooks=[("blocks.0.hook_mlp_out", make_test_hook("hooked"))]):
        out = hooked_model(tokens)
        assert (
            len(hooked_model.blocks[0].hook_mlp_out.bwd_hooks) > 0
        ), "HookedTransformer should register backward hooks"
        out.sum().backward()

    # TransformerBridge: same scenario — compatibility mode exposes the same
    # hook names, so the identical bwd_hooks spec should register and fire.
    with bridge_model.hooks(bwd_hooks=[("blocks.0.hook_mlp_out", make_test_hook("bridge"))]):
        out = bridge_model(tokens)
        assert (
            len(bridge_model.blocks[0].hook_mlp_out.bwd_hooks) > 0
        ), "TransformerBridge should now register backward hooks correctly"
        out.sum().backward()

    # Both hooks must have been invoked during their backward passes.
    assert hook_called["hooked"], "HookedTransformer backward hook should have been called"
    assert hook_called["bridge"], "TransformerBridge backward hook should now be called correctly"
676730
if __name__ == "__main__":
677731
pytest.main([__file__])

transformer_lens/model_bridge/bridge.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,8 +1846,10 @@ def run_with_hooks(
18461846
# Store hooks that we add so we can remove them later
18471847
added_hooks: List[Tuple[HookPoint, str]] = []
18481848

1849-
def add_hook_to_point(hook_point: HookPoint, hook_fn: Callable, name: str):
1850-
hook_point.add_hook(hook_fn)
1849+
def add_hook_to_point(
1850+
hook_point: HookPoint, hook_fn: Callable, name: str, dir: Literal["fwd", "bwd"] = "fwd"
1851+
):
1852+
hook_point.add_hook(hook_fn, dir=dir)
18511853
added_hooks.append((hook_point, name))
18521854

18531855
# Add stop_at_layer hook if specified
@@ -1868,10 +1870,11 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
18681870
block_hook_name = f"blocks.{last_layer_to_process}.hook_out"
18691871
hook_dict = self.hook_dict
18701872
if block_hook_name in hook_dict:
1871-
add_hook_to_point(hook_dict[block_hook_name], stop_hook, block_hook_name)
1873+
add_hook_to_point(hook_dict[block_hook_name], stop_hook, block_hook_name, "fwd")
18721874

18731875
# Helper function to apply hooks based on name or filter function
18741876
def apply_hooks(hooks: List[Tuple[Union[str, Callable], Callable]], is_fwd: bool):
1877+
direction: Literal["fwd", "bwd"] = "fwd" if is_fwd else "bwd"
18751878
# Collect aliases for resolving legacy hook names
18761879
aliases = collect_aliases_recursive(self)
18771880

@@ -1904,13 +1907,15 @@ def wrapped_hook_fn(tensor, hook):
19041907
actual_hook_name = aliases[hook_name_or_filter]
19051908

19061909
if actual_hook_name in hook_dict:
1907-
add_hook_to_point(hook_dict[actual_hook_name], hook_fn, actual_hook_name)
1910+
add_hook_to_point(
1911+
hook_dict[actual_hook_name], hook_fn, actual_hook_name, direction
1912+
)
19081913
else:
19091914
# Filter function
19101915
hook_dict = self.hook_dict
19111916
for name, hook_point in hook_dict.items():
19121917
if hook_name_or_filter(name):
1913-
add_hook_to_point(hook_point, hook_fn, name)
1918+
add_hook_to_point(hook_point, hook_fn, name, direction)
19141919

19151920
try:
19161921
# Apply forward hooks
@@ -2330,10 +2335,18 @@ def _hooks_context():
23302335
# Add forward hooks
23312336
for hook_name, hook_fn in fwd_hooks:
23322337
try:
2333-
self.add_hook(hook_name, hook_fn)
2338+
self.add_hook(hook_name, hook_fn, dir="fwd")
2339+
added_hooks.append((hook_name, hook_fn))
2340+
except Exception as e:
2341+
print(f"Warning: Failed to add forward hook {hook_name}: {e}")
2342+
2343+
# Add backward hooks
2344+
for hook_name, hook_fn in bwd_hooks:
2345+
try:
2346+
self.add_hook(hook_name, hook_fn, dir="bwd")
23342347
added_hooks.append((hook_name, hook_fn))
23352348
except Exception as e:
2336-
print(f"Warning: Failed to add hook {hook_name}: {e}")
2349+
print(f"Warning: Failed to add backward hook {hook_name}: {e}")
23372350

23382351
yield
23392352

0 commit comments

Comments
 (0)