
Commit 999309a

Move QKV separation into bridge that wraps QKV matrix (#1027)
* Move QKV separation to bridge that directly wraps QKV matrix
* Fix typing issues
* Fix hook collection issues
* Ensure standardized hook shape
* Fix syntax error
* Run CI again
* Adjust test to reflect new hook names in QKV bridge
* Simplify getattr in base component
* Add parameter for conversion rule of hook_in and hook_out in QKVBridge
* Move hook point wrapper and add more test coverage

Co-authored-by: Bryce Meyer <[email protected]>
1 parent 787d5c8 commit 999309a

File tree: 18 files changed (+835 −582 lines)
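For context, the core change is that the fused QKV projection is now wrapped by a QKVBridge, which splits the joint projection output into separate q, k, and v streams via a rearrange-style separation rule (the pattern "batch seq (three d_model) -> three batch seq d_model" appears in the updated tests below). A minimal sketch of that separation idea using plain einops follows; it is illustrative only and is not the QKVBridge implementation itself, which wraps this logic in hook conversion/separation rules.

import torch
from einops import rearrange

batch, seq, d_model = 2, 5, 768
# Output of a joint QKV projection (e.g. GPT-2's c_attn), shape (batch, seq, 3 * d_model)
fused = torch.randn(batch, seq, 3 * d_model)

# Split the last dimension into three separate streams: q, k, v
q, k, v = rearrange(fused, "batch seq (three d_model) -> three batch seq d_model", three=3)

assert q.shape == k.shape == v.shape == (batch, seq, d_model)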

tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_integration.py renamed to tests/integration/model_bridge/generalized_components/test_qkv_bridge_integration.py

Lines changed: 40 additions & 24 deletions
@@ -1,4 +1,4 @@
-"""Lightweight integration tests for JointQKVAttentionBridge.
+"""Lightweight integration tests for QKVBridge.
 
 Tests the core functionality without loading large models to keep CI fast.
 """
@@ -7,10 +7,11 @@
 import torch
 
 import transformer_lens.utils as utils
+from transformer_lens.model_bridge.generalized_components.qkv_bridge import QKVBridge
 
 
-class TestJointQKVAttentionBridgeIntegration:
-    """Minimal integration tests for JointQKVAttentionBridge."""
+class TestQKVBridgeIntegration:
+    """Minimal integration tests for QKVBridge."""
 
     def test_hook_alias_resolution(self):
         """Test that hook aliases are properly resolved."""
@@ -24,16 +25,37 @@ def test_hook_alias_resolution(self):
         assert utils.get_act_name("q", 1) == "blocks.1.attn.hook_q"
         assert utils.get_act_name("k", 2) == "blocks.2.attn.hook_k"
 
-    def test_component_class_exists(self):
-        """Test that JointQKVAttentionBridge class can be imported."""
+    def test_joint_qkv_attention_bridge_properties(self):
+        """Test that JointQKVAttentionBridge properties are properly resolved."""
         from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
             JointQKVAttentionBridge,
         )
 
+        class TestConfig:
+            n_heads = 12
+
+        qkv_bridge = QKVBridge(name="qkv", config=TestConfig())
+
+        qkv_attention_bridge = JointQKVAttentionBridge(
+            name="blocks.0.attn",
+            config=TestConfig(),
+            submodules={"qkv": qkv_bridge},
+        )
+
+        assert qkv_attention_bridge.q.hook_in == qkv_bridge.q_hook_in
+        assert qkv_attention_bridge.q.hook_out == qkv_bridge.q_hook_out
+        assert qkv_attention_bridge.k.hook_in == qkv_bridge.k_hook_in
+        assert qkv_attention_bridge.k.hook_out == qkv_bridge.k_hook_out
+        assert qkv_attention_bridge.v.hook_in == qkv_bridge.v_hook_in
+        assert qkv_attention_bridge.v.hook_out == qkv_bridge.v_hook_out
+
+    def test_component_class_exists(self):
+        """Test that QKVBridge class can be imported."""
+
         # Verify the class exists and has expected methods
-        assert hasattr(JointQKVAttentionBridge, "forward")
-        assert hasattr(JointQKVAttentionBridge, "_reconstruct_attention")
-        assert hasattr(JointQKVAttentionBridge, "_manual_attention_computation")
+        assert hasattr(QKVBridge, "forward")
+        assert hasattr(QKVBridge, "_create_qkv_conversion_rule")
+        assert hasattr(QKVBridge, "_create_qkv_separation_rule")
 
     def test_hook_point_has_hooks_method(self):
         """Test that HookPoint.has_hooks method works correctly."""
@@ -60,9 +82,9 @@ def dummy_hook(x, hook):
         assert not hook_point.has_hooks()
 
     def test_architecture_imports(self):
-        """Test that architecture files can be imported and reference JointQKVAttentionBridge."""
+        """Test that architecture files can be imported and reference QKVBridge."""
         # Test that we can import the architecture files without errors
-        # Test that JointQKVAttentionBridge is referenced in the source files
+        # Test that QKVBridge is referenced in the source files
         import inspect
 
         from transformer_lens.model_bridge.supported_architectures import (
@@ -72,19 +94,13 @@ def test_architecture_imports(self):
         )
 
         gpt2_source = inspect.getsource(gpt2)
-        assert (
-            "JointQKVAttentionBridge" in gpt2_source
-        ), "GPT-2 architecture should reference JointQKVAttentionBridge"
+        assert "QKVBridge" in gpt2_source, "GPT-2 architecture should reference QKVBridge"
 
         bloom_source = inspect.getsource(bloom)
-        assert (
-            "JointQKVAttentionBridge" in bloom_source
-        ), "BLOOM architecture should reference JointQKVAttentionBridge"
+        assert "QKVBridge" in bloom_source, "BLOOM architecture should reference QKVBridge"
 
         neox_source = inspect.getsource(neox)
-        assert (
-            "JointQKVAttentionBridge" in neox_source
-        ), "NeoX architecture should reference JointQKVAttentionBridge"
+        assert "QKVBridge" in neox_source, "NeoX architecture should reference QKVBridge"
 
     @pytest.mark.skip(reason="Requires model loading - too slow for CI")
     def test_distilgpt2_integration(self):
@@ -96,15 +112,15 @@ def test_distilgpt2_integration(self):
         torch.set_grad_enabled(False)
         model = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
 
-        # Verify JointQKVAttentionBridge usage
-        joint_qkv_modules = [
+        # Verify QKVBridge usage
+        qkv_bridge_modules = [
            name
            for name, module in model.named_modules()
-            if "JointQKVAttentionBridge" in getattr(module, "__class__", {}).get("__name__", "")
+            if "QKVBridge" in getattr(module, "__class__", {}).get("__name__", "")
        ]
        assert (
-            len(joint_qkv_modules) == 6
-        ), f"Expected 6 JointQKVAttentionBridge modules, got {len(joint_qkv_modules)}"
+            len(qkv_bridge_modules) == 6
+        ), f"Expected 6 QKVBridge modules, got {len(qkv_bridge_modules)}"
 
         # Test basic functionality
         tokens = model.to_tokens("Test")

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 34 additions & 25 deletions
@@ -154,14 +154,13 @@ def test_component_access():
 
 
 def test_joint_qkv_custom_conversion_rule():
-    """Test that custom QKV conversion rules can be passed to JointQKVAttentionBridge."""
+    """Test that custom QKV conversion rules can be passed to QKVBridge."""
     from transformer_lens.conversion_utils.conversion_steps.rearrange_hook_conversion import (
         RearrangeHookConversion,
     )
-    from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
-        JointQKVAttentionBridge,
+    from transformer_lens.model_bridge.generalized_components.qkv_bridge import (
+        QKVBridge,
     )
-    from transformer_lens.model_bridge.generalized_components.linear import LinearBridge
 
     model_name = "gpt2"  # Use a smaller model for testing
     bridge = TransformerBridge.boot_transformers(model_name)
@@ -172,36 +171,46 @@ def test_joint_qkv_custom_conversion_rule():
         num_attention_heads=12,  # GPT-2 small has 12 heads
     )
 
-    # Create QKV config
-    qkv_config = {
-        "split_qkv_matrix": lambda x: (x, x, x),  # Dummy function for test
-    }
-
-    # Create submodules
-    submodules = {
-        "qkv": LinearBridge(name="c_attn"),
-        "o": LinearBridge(name="c_proj"),
-    }
+    custom_qkv_separation = RearrangeHookConversion(
+        "batch seq (three d_model) -> three batch seq d_model",
+        three=3,
+    )
 
     # This should not raise an error
-    test_bridge = JointQKVAttentionBridge(
-        name="test_joint_qkv",
-        model_config=bridge.cfg,
-        submodules=submodules,
-        qkv_config=qkv_config,
+    test_bridge = QKVBridge(
+        name="test_qkv_bridge",
+        config=bridge.cfg,
+        submodules={},
         qkv_conversion_rule=custom_qkv_conversion,
+        qkv_separation_rule=custom_qkv_separation,
     )
 
     # Verify the custom conversion rule was set on Q, K, V components
     assert (
-        test_bridge.q.hook_out.hook_conversion is custom_qkv_conversion
-    ), "Custom QKV conversion rule should be set on Q"
+        test_bridge.q_hook_in.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on hook_in of Q"
+    assert (
+        test_bridge.k_hook_in.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on hook_in of K"
+    assert (
+        test_bridge.v_hook_in.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on hook_in of V"
+    assert (
+        test_bridge.q_hook_out.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on hook_out of Q"
+    assert (
+        test_bridge.k_hook_out.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on hook_out of K"
+    assert (
+        test_bridge.v_hook_out.hook_conversion is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set on hook_out of V"
+
     assert (
-        test_bridge.k.hook_out.hook_conversion is custom_qkv_conversion
-    ), "Custom QKV conversion rule should be set on K"
+        test_bridge.qkv_conversion_rule is custom_qkv_conversion
+    ), "Custom QKV conversion rule should be set"
     assert (
-        test_bridge.v.hook_out.hook_conversion is custom_qkv_conversion
-    ), "Custom QKV conversion rule should be set on V"
+        test_bridge.qkv_separation_rule is custom_qkv_separation
+    ), "Custom QKV separation rule should be set"
 
 
 def test_attention_pattern_hook_shape_custom_conversion():
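The test above exercises two kinds of rules: a conversion rule applied to each q/k/v hook_in and hook_out, and a separation rule that splits the fused projection. A typical per-head conversion reshapes a flat d_model activation into the standardized hook shape (batch, pos, n_heads, d_head), which is also the shape the head-ablation test below indexes into. The sketch below is a hypothetical illustration of that reshape with einops; the actual pattern string passed to custom_qkv_conversion is not shown in this excerpt.

import torch
from einops import rearrange

batch, pos, n_heads, d_head = 2, 5, 12, 64
# A flat q/k/v stream of width n_heads * d_head (GPT-2 small sized)
flat = torch.randn(batch, pos, n_heads * d_head)

# Reshape into the standardized per-head hook shape (batch, pos, n_heads, d_head)
per_head = rearrange(flat, "batch pos (n_heads d_head) -> batch pos n_heads d_head", n_heads=n_heads)
assert per_head.shape == (batch, pos, n_heads, d_head)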
Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+"""Integration tests for QKV hook compatibility in TransformerBridge."""
+
+import torch
+
+from transformer_lens.model_bridge import TransformerBridge
+
+
+class TestQKVHookCompatibility:
+    """Test that QKV bridge hooks are compatible with overall model hook access."""
+
+    def test_v_hook_out_equals_blocks_attn_hook_v(self):
+        """Test that v_hook_out in QKV bridge equals blocks.0.attn.hook_v on the overall model."""
+        # Load GPT-2 in TransformerBridge
+        bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
+
+        # Turn on compatibility mode
+        bridge.enable_compatibility_mode(disable_warnings=True)
+
+        # Create test input
+        test_input = torch.tensor([[1, 2, 3, 4, 5]])  # Simple test sequence
+
+        # Get the QKV bridge from the first attention layer
+        qkv_bridge = bridge.blocks[0].attn.qkv
+
+        # Verify that qkv_bridge is indeed a QKVBridge
+        from transformer_lens.model_bridge.generalized_components.qkv_bridge import (
+            QKVBridge,
+        )
+
+        assert isinstance(qkv_bridge, QKVBridge), "First attention layer should have a QKVBridge"
+
+        # Run a forward pass to populate the hooks
+        with torch.no_grad():
+            _ = bridge(test_input)
+
+        # Assert that v_hook_out in the QKV bridge is the same object as
+        # blocks.0.attn.hook_v on the overall model
+        assert (
+            qkv_bridge.v_hook_out is bridge.blocks[0].attn.hook_v
+        ), "v_hook_out in QKV bridge should be the same object as blocks.0.attn.hook_v"
+
+        # Also test that the hook points have the same properties
+        assert (
+            qkv_bridge.v_hook_out.has_hooks() == bridge.blocks[0].attn.hook_v.has_hooks()
+        ), "Hook points should have the same hook status"
+
+    def test_q_hook_out_equals_blocks_attn_hook_q(self):
+        """Test that q_hook_out in QKV bridge equals blocks.0.attn.hook_q on the overall model."""
+        # Load GPT-2 in TransformerBridge
+        bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
+
+        # Turn on compatibility mode
+        bridge.enable_compatibility_mode(disable_warnings=True)
+
+        # Create test input
+        test_input = torch.tensor([[1, 2, 3, 4, 5]])  # Simple test sequence
+
+        # Get the QKV bridge from the first attention layer
+        qkv_bridge = bridge.blocks[0].attn.qkv
+
+        # Run a forward pass to populate the hooks
+        with torch.no_grad():
+            _ = bridge(test_input)
+
+        # Assert that q_hook_out in the QKV bridge is the same object as
+        # blocks.0.attn.hook_q on the overall model
+        assert (
+            qkv_bridge.q_hook_out is bridge.blocks[0].attn.hook_q
+        ), "q_hook_out in QKV bridge should be the same object as blocks.0.attn.hook_q"
+
+    def test_k_hook_out_equals_blocks_attn_hook_k(self):
+        """Test that k_hook_out in QKV bridge equals blocks.0.attn.hook_k on the overall model."""
+        # Load GPT-2 in TransformerBridge
+        bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
+
+        # Turn on compatibility mode
+        bridge.enable_compatibility_mode(disable_warnings=True)
+
+        # Create test input
+        test_input = torch.tensor([[1, 2, 3, 4, 5]])  # Simple test sequence
+
+        # Get the QKV bridge from the first attention layer
+        qkv_bridge = bridge.blocks[0].attn.qkv
+
+        # Run a forward pass to populate the hooks
+        with torch.no_grad():
+            _ = bridge(test_input)
+
+        # Assert that k_hook_out in the QKV bridge is the same object as
+        # blocks.0.attn.hook_k on the overall model
+        assert (
+            qkv_bridge.k_hook_out is bridge.blocks[0].attn.hook_k
+        ), "k_hook_out in QKV bridge should be the same object as blocks.0.attn.hook_k"
+
+    def test_hook_aliases_work_correctly(self):
+        """Test that hook aliases work correctly in compatibility mode."""
+        # Load GPT-2 in TransformerBridge
+        bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
+
+        # Turn on compatibility mode
+        bridge.enable_compatibility_mode(disable_warnings=True)
+
+        # Create test input
+        test_input = torch.tensor([[1, 2, 3, 4, 5]])  # Simple test sequence
+
+        # Get the QKV bridge from the first attention layer
+        qkv_bridge = bridge.blocks[0].attn.qkv
+
+        # Run a forward pass to populate the hooks
+        with torch.no_grad():
+            _ = bridge(test_input)
+
+        # Test that hook aliases work correctly
+        # These should all reference the same hook points
+        assert qkv_bridge.q_hook_out is bridge.blocks[0].attn.hook_q, "Q hook alias should work"
+        assert qkv_bridge.k_hook_out is bridge.blocks[0].attn.hook_k, "K hook alias should work"
+        assert qkv_bridge.v_hook_out is bridge.blocks[0].attn.hook_v, "V hook alias should work"
+
+        # Test that the hook points are accessible through the attention bridge properties
+        assert qkv_bridge.q_hook_out is bridge.blocks[0].attn.q.hook_out, "Q property should work"
+        assert qkv_bridge.k_hook_out is bridge.blocks[0].attn.k.hook_out, "K property should work"
+        assert qkv_bridge.v_hook_out is bridge.blocks[0].attn.v.hook_out, "V property should work"
+
+    def test_head_ablation_hook_works_correctly(self):
+        """Test that head ablation hook works correctly with TransformerBridge."""
+        # Load GPT-2 in TransformerBridge
+        bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
+
+        # Turn on compatibility mode
+        bridge.enable_compatibility_mode(disable_warnings=True)
+
+        # Create test tokens (same as in the demo)
+        gpt2_tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
+
+        layer_to_ablate = 0
+        head_index_to_ablate = 8
+
+        # Test both hook names
+        hook_names_to_test = [
+            "blocks.0.attn.hook_v",  # Compatibility mode alias
+            "blocks.0.attn.v.hook_out",  # Direct property access
+        ]
+
+        for hook_name in hook_names_to_test:
+            print(f"\nTesting hook name: {hook_name}")
+
+            # Track if the hook was called
+            hook_called = False
+            mutation_applied = False
+
+            # We define a head ablation hook
+            def head_ablation_hook(value, hook):
+                nonlocal hook_called, mutation_applied
+                hook_called = True
+                print(f"Shape of the value tensor: {value.shape}")
+
+                # Apply the ablation (out-of-place to avoid view modification error)
+                result = value.clone()
+                result[:, :, head_index_to_ablate, :] = 0.0
+
+                # Check if the mutation was applied (the result should be zero for the ablated head)
+                if torch.all(result[:, :, head_index_to_ablate, :] == 0.0):
+                    mutation_applied = True
+
+                return result
+
+            # Get original loss
+            original_loss = bridge(gpt2_tokens, return_type="loss")
+
+            # Run with head ablation hook
+            ablated_loss = bridge.run_with_hooks(
+                gpt2_tokens, return_type="loss", fwd_hooks=[(hook_name, head_ablation_hook)]
+            )
+
+            print(f"Original Loss: {original_loss.item():.3f}")
+            print(f"Ablated Loss: {ablated_loss.item():.3f}")
+
+            # Assert that the hook was called
+            assert hook_called, f"Head ablation hook should have been called for {hook_name}"
+
+            # Assert that the mutation was applied
+            assert (
+                mutation_applied
+            ), f"Mutation should have been applied to the value tensor for {hook_name}"
+
+            # Assert that ablated loss is higher than original loss (ablation should hurt performance)
+            assert (
+                ablated_loss.item() > original_loss.item()
+            ), f"Ablated loss should be higher than original loss for {hook_name}"
+
+            print(f"✅ Hook {hook_name} works correctly!")
