
Commit 98d649a

Attention shape normalization (#997)
* created input output reshape hook functionality
* refactor weight conversion to hook conversion
* ran format
* moved conversion utils
* added auto attention attempt
* added revert function
* ran format
* fixed type checks
* fixed tests
* ran format
* ran format
* fixed test
* replaced config creation with more robust object
* restored init file
* removed type checking
* updated hook to match what it should have been
* made pattern rules a param
* updated jqv component to allow conversion to be passed through
* made c_attn sub module
* fixed type issues
* fixed test
* simplified config flow
* fixed test
* set attention output
* removed output attentions config var
1 parent f244eff commit 98d649a


41 files changed: +664 additions, -127 deletions
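The commit message describes reshape hooks that normalize Hugging Face activation shapes into the shapes TransformerLens hooks expose, plus a revert step so tensors are handed back to the wrapped module unchanged. A rough, hedged sketch of that idea only, written with plain einops; the helper names here are illustrative and this is not the actual RearrangeHookConversion implementation, which the diffs below exercise through the attention bridges:

# Sketch of the shape-normalization idea, not the library's implementation.
import torch
from einops import rearrange

n_heads, d_head = 12, 64  # GPT-2 small values, matching the tests below

def to_lens_shape(qkv_out: torch.Tensor) -> torch.Tensor:
    # HF-style joint projection output: (batch, seq, n_heads * d_head)
    # TransformerLens-style hook shape:  (batch, seq, n_heads, d_head)
    return rearrange(qkv_out, "batch seq (h d) -> batch seq h d", h=n_heads, d=d_head)

def revert(lens_shaped: torch.Tensor) -> torch.Tensor:
    # Revert function: undo the conversion before returning to the HF module
    return rearrange(lens_shaped, "batch seq h d -> batch seq (h d)")

x = torch.randn(2, 5, n_heads * d_head)
assert revert(to_lens_shape(x)).shape == x.shape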

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 159 additions & 0 deletions
@@ -122,5 +122,164 @@ def test_component_access():
     assert hasattr(block, "ln2"), "Block should have second layer norm"
 
 
+def test_joint_qkv_custom_conversion_rule():
+    """Test that custom QKV conversion rules can be passed to JointQKVAttentionBridge."""
+    from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
+        RearrangeHookConversion,
+    )
+    from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
+        JointQKVAttentionBridge,
+    )
+    from transformer_lens.model_bridge.generalized_components.linear import LinearBridge
+
+    model_name = "gpt2"  # Use a smaller model for testing
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    # Create a custom QKV conversion rule
+    custom_qkv_conversion = RearrangeHookConversion(
+        "batch seq (num_attention_heads d_head) -> batch seq num_attention_heads d_head",
+        num_attention_heads=12,  # GPT-2 small has 12 heads
+    )
+
+    # Create QKV config
+    qkv_config = {
+        "split_qkv_matrix": lambda x: (x, x, x),  # Dummy function for test
+    }
+
+    # Create submodules
+    submodules = {
+        "qkv": LinearBridge(name="c_attn"),
+        "o": LinearBridge(name="c_proj"),
+    }
+
+    # This should not raise an error
+    test_bridge = JointQKVAttentionBridge(
+        name="test_joint_qkv",
+        model_config=bridge.cfg,
+        submodules=submodules,
+        qkv_config=qkv_config,
+        qkv_conversion_rule=custom_qkv_conversion,
+    )
+
+    # Verify the custom conversion rule was set on Q, K, V components
+    assert (
+        test_bridge.q.hook_out.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on Q"
+    assert (
+        test_bridge.k.hook_out.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on K"
+    assert (
+        test_bridge.v.hook_out.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on V"
+
+
+def test_attention_pattern_hook_shape_custom_conversion():
+    """Test that custom pattern conversion rules can be passed to attention components."""
+    from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
+        RearrangeHookConversion,
+    )
+
+    model_name = "gpt2"  # Use a smaller model for testing
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    if bridge.tokenizer.pad_token is None:
+        bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
+
+    # Create a custom conversion rule (this is just for testing the parameter passing)
+    custom_conversion = RearrangeHookConversion(
+        "batch n_heads pos_q pos_k -> batch n_heads pos_q pos_k"  # Same as default but explicitly set
+    )
+
+    # Verify that the attention bridge accepts the custom conversion parameter
+    # We can't easily test this with the existing bridge without recreating it,
+    # but we can at least verify the parameter is accepted without error
+    from transformer_lens.model_bridge.generalized_components.attention import (
+        AttentionBridge,
+    )
+
+    # This should not raise an error
+    test_bridge = AttentionBridge(
+        name="test_attn", config=bridge.cfg, pattern_conversion_rule=custom_conversion
+    )
+
+    # Verify the conversion rule was set
+    assert (
+        test_bridge.hook_pattern.hook_conversion is custom_conversion
+    ), "Custom conversion rule should be set"
+
+
+def test_attention_pattern_hook_shape():
+    """Test that the attention pattern hook produces the correct shape (batch, n_heads, pos, pos)."""
+    model_name = "gpt2"  # Use a smaller model for testing
+    bridge = TransformerBridge.boot_transformers(
+        model_name,
+        hf_config_overrides={
+            "attn_implementation": "eager",
+        },
+    )
+
+    if bridge.tokenizer.pad_token is None:
+        bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
+
+    # Attention output enabled via hf_config_overrides
+
+    # Variable to store captured attention patterns
+    captured_patterns = {}
+
+    def capture_pattern_hook(tensor, hook):
+        """Hook to capture attention patterns."""
+        captured_patterns[hook.name] = tensor.clone()
+        return tensor
+
+    # Add hook to capture attention patterns
+    bridge.blocks[0].attn.hook_pattern.add_hook(capture_pattern_hook)
+
+    try:
+        # Run model with a prompt
+        prompt = "The quick brown fox"
+        tokens = bridge.to_tokens(prompt)
+        batch_size, seq_len = tokens.shape
+
+        # Run forward pass
+        output = bridge(tokens)
+
+        # Verify we captured attention patterns
+        assert len(captured_patterns) > 0, "Should have captured attention patterns"
+
+        # Get the captured pattern tensor
+        pattern_tensor = list(captured_patterns.values())[0]
+
+        # Verify the shape is (batch, n_heads, pos, pos)
+        assert (
+            len(pattern_tensor.shape) == 4
+        ), f"Pattern tensor should be 4D, got {len(pattern_tensor.shape)}D"
+
+        batch_dim, n_heads_dim, pos_q_dim, pos_k_dim = pattern_tensor.shape
+
+        # Verify dimensions make sense
+        assert batch_dim == batch_size, f"Batch dimension should be {batch_size}, got {batch_dim}"
+        assert (
+            n_heads_dim == bridge.cfg.n_heads
+        ), f"Heads dimension should be {bridge.cfg.n_heads}, got {n_heads_dim}"
+        assert (
+            pos_q_dim == seq_len
+        ), f"Query position dimension should be {seq_len}, got {pos_q_dim}"
+        assert pos_k_dim == seq_len, f"Key position dimension should be {seq_len}, got {pos_k_dim}"
+
+        # Verify it's actually attention weights (should be non-negative and roughly sum to 1 along last dim)
+        assert torch.all(pattern_tensor >= 0), "Attention patterns should be non-negative"
+
+        # Check that attention weights roughly sum to 1 along the last dimension (with some tolerance for numerical precision)
+        attention_sums = pattern_tensor.sum(dim=-1)
+        expected_sums = torch.ones_like(attention_sums)
+        assert torch.allclose(
+            attention_sums, expected_sums, atol=1e-5
+        ), "Attention patterns should sum to ~1 along key dimension"
+
+    finally:
+        # Clean up hooks
+        bridge.blocks[0].attn.hook_pattern.remove_hooks()
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

tests/integration/model_bridge/test_bridge_root_module_cache_compatibility.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,15 @@
 MODEL = "gpt2"
 
 prompt = "Hello World!"
-bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
+bridge = TransformerBridge.boot_transformers(
+    MODEL,
+    device="cpu",
+    hf_config_overrides={
+        "attn_implementation": "eager",
+    },
+)
+
+# Attention output enabled via hf_config_overrides
 
 act_names_in_cache = [
     # "hook_embed",

tests/mocks/architecture_adapter.py

Lines changed: 6 additions & 6 deletions
@@ -1,4 +1,6 @@
 """Mock architecture adapter for testing."""
+from types import SimpleNamespace
+
 import pytest
 import torch.nn as nn
 
@@ -18,13 +20,11 @@ class MockArchitectureAdapter(ArchitectureAdapter):
     def __init__(self, cfg=None):
         if cfg is None:
             # Create a minimal config for testing
-            cfg = type(
-                "MockConfig",
-                (),
-                {"d_mlp": 512, "intermediate_size": 512, "default_prepend_bos": True},
-            )()
+            cfg = SimpleNamespace(d_mlp=512, intermediate_size=512, default_prepend_bos=True)
         super().__init__(cfg)
         # Use actual bridge instances instead of tuples
+        # Provide minimal config to components that require it
+        attn_cfg = SimpleNamespace(n_heads=1)
         self.component_mapping = {
             "embed": EmbeddingBridge(name="embed"),
             "unembed": EmbeddingBridge(name="unembed"),
@@ -34,7 +34,7 @@ def __init__(self, cfg=None):
                 submodules={
                     "ln1": NormalizationBridge(name="ln1"),
                     "ln2": NormalizationBridge(name="ln2"),
-                    "attn": AttentionBridge(name="attn"),
+                    "attn": AttentionBridge(name="attn", config=attn_cfg),
                     "mlp": MLPBridge(name="mlp"),
                 },
             ),
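The mock config change above replaces a dynamically created type(...) class with types.SimpleNamespace, which gives the same attribute access with less ceremony. A small standalone illustration of the equivalence, standard library only and independent of the test code above:

from types import SimpleNamespace

# Old style: build a throwaway class, then instantiate it
cfg_old = type("MockConfig", (), {"d_mlp": 512, "default_prepend_bos": True})()

# New style: one readable call, same attribute access
cfg_new = SimpleNamespace(d_mlp=512, default_prepend_bos=True)

assert cfg_old.d_mlp == cfg_new.d_mlp == 512
# SimpleNamespace also provides a readable repr and attribute-wise equality for free
print(cfg_new)  # namespace(d_mlp=512, default_prepend_bos=True)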

tests/unit/model_bridge/test_bridge.py

Lines changed: 7 additions & 4 deletions
@@ -4,6 +4,7 @@
 and other bridge operations.
 """
 
+from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import pytest
@@ -42,7 +43,9 @@ def mock_get_component(model, path):
         comp.set_original_component(model.unembed)
         return comp
     elif "blocks" in path and "attn" in path:
-        comp = AttentionBridge(name="attn")
+        # Minimal config with n_heads for AttentionBridge
+        attn_cfg = SimpleNamespace(n_heads=1)
+        comp = AttentionBridge(name="attn", config=attn_cfg)
         comp.set_original_component(model.blocks[0].attn)
         return comp
     elif "blocks" in path and "mlp" in path:
@@ -99,7 +102,7 @@ def test_format_block_mapping_tuple(self):
             submodules={
                 "ln1": NormalizationBridge(name="ln1"),
                 "ln2": NormalizationBridge(name="ln2"),
-                "attn": AttentionBridge(name="attn"),
+                "attn": AttentionBridge(name="attn", config=SimpleNamespace(n_heads=1)),
                 "mlp": MLPBridge(name="mlp"),
             },
         )
@@ -124,7 +127,7 @@ def test_format_mixed_mapping(self):
                 name="blocks",
                 submodules={
                     "ln1": NormalizationBridge(name="ln1"),
-                    "attn": AttentionBridge(name="attn"),
+                    "attn": AttentionBridge(name="attn", config=SimpleNamespace(n_heads=1)),
                 },
             ),
             "ln_final": NormalizationBridge(name="ln_final"),
@@ -145,7 +148,7 @@ def test_format_with_prepend_path(self):
         """Test formatting with prepend path parameter."""
         mapping = {
             "ln1": NormalizationBridge(name="ln1"),
-            "attn": AttentionBridge(name="attn"),
+            "attn": AttentionBridge(name="attn", config=SimpleNamespace(n_heads=1)),
         }
        # To test prepending, we need a parent structure in the component mapping
         self.bridge.adapter.component_mapping = {

tests/unit/model_bridge/test_component_setup.py

Lines changed: 9 additions & 3 deletions
@@ -1,6 +1,8 @@
 """Tests for component setup utilities."""
 
 
+from types import SimpleNamespace
+
 import pytest
 import torch.nn as nn
 
@@ -55,6 +57,7 @@ def test_setup_submodules_basic(self, mock_model_adapter):
         # Create a component with submodules
         component = AttentionBridge(
             name="self_attn",
+            config=SimpleNamespace(n_heads=1),
             submodules={
                 "q_proj": EmbeddingBridge(name="q_proj"),
                 "k_proj": EmbeddingBridge(name="k_proj"),
@@ -82,10 +85,13 @@ def test_setup_submodules_nested(self):
 
         # Create a component with nested submodules
         inner_component = AttentionBridge(
-            name="q_proj", submodules={}  # This should match a real path
+            name="q_proj",
+            config=SimpleNamespace(n_heads=1),
+            submodules={},  # This should match a real path
         )
         component = AttentionBridge(
             name="attn",
+            config=SimpleNamespace(n_heads=1),
             submodules={
                 "q_proj": inner_component,
             },
@@ -169,7 +175,7 @@ def test_setup_blocks_bridge(self):
             submodules={
                 "ln1": NormalizationBridge(name="ln1"),
                 "ln2": NormalizationBridge(name="ln2"),
-                "attn": AttentionBridge(name="attn"),
+                "attn": AttentionBridge(name="attn", config=SimpleNamespace(n_heads=1)),
                 "mlp": MLPBridge(name="mlp"),
             },
         )
@@ -240,7 +246,7 @@ def __init__(self):
             submodules={
                 "ln1": NormalizationBridge(name="ln1"),
                 "ln2": NormalizationBridge(name="ln2"),
-                "attn": AttentionBridge(name="attn"),
+                "attn": AttentionBridge(name="attn", config=SimpleNamespace(n_heads=1)),
                 "mlp": MLPBridge(name="mlp"),
            },
         ),

tests/unit/model_bridge/test_end_to_end_bridge.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 """End-to-end tests for the TransformerBridge."""
+from types import SimpleNamespace
 from unittest.mock import MagicMock
 
 import torch.nn as nn
@@ -35,7 +36,7 @@ def test_bridge_creation_and_component_access(self):
                 name="encoder.layers",
                 submodules={
                     "ln1": NormalizationBridge(name="norm1"),
-                    "attn": AttentionBridge(name="self_attn"),
+                    "attn": AttentionBridge(name="self_attn", config=SimpleNamespace(n_heads=1)),
                 },
             ),
         }

tests/unit/test_hook_points.py

Lines changed: 22 additions & 1 deletion
@@ -44,8 +44,9 @@ def test_add_hook_with_level(mock_handle):
     assert hook_point.fwd_hooks[0].context_level == 5
 
 
+@mock.patch("transformer_lens.hook_points.LensHandle")
 @mock.patch("torch.utils.hooks.RemovableHandle")
-def test_add_hook_prepend(mock_handle):
+def test_add_hook_prepend(mock_handle, mock_lens_handle):
     mock_handle.id = 0
     mock_handle.next_id = 1
 
@@ -57,6 +58,26 @@ def hook1(activation, hook):
     def hook2(activation, hook):
         return activation
 
+    # Make LensHandle constructor return a simple container capturing the pt_handle ('hook')
+    class _LensHandleBox:
+        def __init__(self, handle, is_permanent, context_level):
+            self.hook = handle
+            self.is_permanent = is_permanent
+            self.context_level = context_level
+
+    mock_lens_handle.side_effect = _LensHandleBox
+
+    # Override register_forward_hook to return mocked handles with incremental ids
+    next_id = {"val": 1}
+
+    def fake_register_forward_hook(fn, prepend=False):
+        handle = mock.MagicMock()
+        handle.id = next_id["val"]
+        next_id["val"] += 1
+        return handle
+
+    hook_point.register_forward_hook = fake_register_forward_hook  # type: ignore[assignment]
+
     hook_point.add_hook(hook1, dir="fwd")
     hook_point.add_hook(hook2, dir="fwd", prepend=True)
 
transformer_lens/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 from . import utilities
 from . import hook_points
 from . import evals
+from . import conversion_utils
 from .past_key_value_caching import (
     HookedTransformerKeyValueCache,
     HookedTransformerKeyValueCacheEntry,
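With conversion_utils imported in the package root, the conversion rules exercised by the new tests are reachable through the public package. A hedged usage sketch that simply mirrors the import path and arguments shown in the integration test above, rather than documenting a stable API:

import transformer_lens  # conversion_utils is now an attribute of the top-level package

from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
    RearrangeHookConversion,
)

# Same rearrangement pattern the new tests construct for GPT-2 small
custom_qkv_conversion = RearrangeHookConversion(
    "batch seq (num_attention_heads d_head) -> batch seq num_attention_heads d_head",
    num_attention_heads=12,
)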
