Commit ee9b44b

Split weights instead of logits for models with joint QKV matrix (#1043)
* Split weights instead of logits for models with joint QKV activation
* Adjust tests accordingly
* Set split_qkv_matrix function inside init
* Remove debugging print statements
* added model fixture
* attempted memory reduction
* fixed test
* fixed test
* removed extra test
* cleaned up test
* applied input hook to attention

---------

Co-authored-by: Bryce Meyer <[email protected]>
1 parent 58e788f commit ee9b44b
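
The core idea of the change: in architectures like GPT-2, BLOOM, and NeoX, attention stores Q, K, and V as a single joint matrix, and the previous approach computed the joint projection first and then split its output (the "logits" in the commit title) into q/k/v. Splitting the weight matrix itself instead means each projection is computed independently, so q/k/v hooks intervene on genuinely separate computations. A minimal sketch of the weight split, for illustration only: the shapes assume GPT-2-style packing, and the helper name echoes the commit's split_qkv_matrix but its real signature may differ.

import torch

def split_qkv_matrix(w_qkv: torch.Tensor, b_qkv: torch.Tensor):
    """Illustrative sketch: split a joint QKV weight/bias into separate projections.

    Assumes w_qkv has shape [d_model, 3 * d_model] with Q, K, V
    concatenated along the last dimension (GPT-2-style c_attn packing).
    """
    d_model = w_qkv.shape[0]
    w_q, w_k, w_v = torch.split(w_qkv, d_model, dim=-1)
    b_q, b_k, b_v = torch.split(b_qkv, d_model, dim=-1)
    return (w_q, b_q), (w_k, b_k), (w_v, b_v)

# Usage sketch: three independent projections instead of one joint one.
d_model = 768
w_qkv = torch.randn(d_model, 3 * d_model)
b_qkv = torch.randn(3 * d_model)
x = torch.randn(2, 5, d_model)  # [batch, pos, d_model]

(w_q, b_q), (w_k, b_k), (w_v, b_v) = split_qkv_matrix(w_qkv, b_qkv)
q = x @ w_q + b_q  # matches slicing the joint output, up to float error
k = x @ w_k + b_k
v = x @ w_v + b_v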

File tree

18 files changed, +873 -615 lines changed

tests/conftest.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+"""Global pytest configuration for memory management and test optimization."""
+
+import gc
+
+import pytest
+import torch
+
+
+@pytest.fixture(autouse=True, scope="function")
+def cleanup_memory():
+    """Automatically clean up memory after each test."""
+    yield
+    # Clear torch cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # Force garbage collection multiple times for better cleanup
+    for _ in range(3):
+        gc.collect()
+
+
+@pytest.fixture(autouse=True, scope="class")
+def cleanup_class_memory():
+    """Clean up memory after each test class."""
+    yield
+    # More aggressive cleanup after test classes
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+
+
+# Configure pytest to be more memory-efficient
+def pytest_configure(config):
+    """Configure pytest for better memory usage."""
+    # Set torch to use less memory
+    torch.set_num_threads(1)  # Reduce threading overhead
+
+    # Configure garbage collection to be more aggressive
+    gc.set_threshold(700, 10, 10)
+
+
+def pytest_sessionfinish(session, exitstatus):
+    """Clean up at the end of test session."""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
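
Because both fixtures are marked autouse, every test in the suite gets the cleanup without requesting it by name. A hypothetical test file (not part of this commit) illustrating the effect:

import torch

def test_allocates_a_large_tensor():
    # No fixture argument is needed: the autouse cleanup_memory fixture
    # in conftest.py still runs after this test finishes, emptying the
    # CUDA cache (when available) and forcing garbage collection.
    x = torch.zeros(1024, 1024)
    assert x.sum().item() == 0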

tests/integration/model_bridge/compatibility/test_hooks.py

Lines changed: 14 additions & 9 deletions
@@ -6,10 +6,15 @@
 MODEL = "gpt2"  # Use a model supported by TransformerBridge

 prompt = "Hello World!"
-model = TransformerBridge.boot_transformers(MODEL, device="cpu")
 embed = lambda name: name == "hook_embed"


+@pytest.fixture(scope="module")
+def model():
+    """Load model once per test module to reduce memory usage."""
+    return TransformerBridge.boot_transformers(MODEL, device="cpu")
+
+
 class Counter:
     def __init__(self):
         self.count = 0
@@ -18,7 +23,7 @@ def inc(self, *args, **kwargs):
         self.count += 1


-def test_hook_attaches_normally():
+def test_hook_attaches_normally(model):
     """Test that hooks can be attached and removed normally with TransformerBridge."""
     c = Counter()
     _ = model.run_with_hooks(prompt, fwd_hooks=[(embed, c.inc)])
@@ -40,7 +45,7 @@ def test_hook_attaches_normally():
         pass


-def test_perma_hook_attaches_normally():
+def test_perma_hook_attaches_normally(model):
     """Test that permanent hooks can be attached with TransformerBridge."""
     c = Counter()
@@ -66,7 +71,7 @@ def test_perma_hook_attaches_normally():
         pass


-def test_hook_context_manager():
+def test_hook_context_manager(model):
     """Test that hook context manager works with TransformerBridge."""
     c = Counter()
@@ -95,7 +100,7 @@ def test_hook_context_manager():
         pass


-def test_run_with_cache_functionality():
+def test_run_with_cache_functionality(model):
     """Test that run_with_cache works with TransformerBridge."""
     try:
         output, cache = model.run_with_cache(prompt)
@@ -122,7 +127,7 @@ def test_run_with_cache_functionality():
         pytest.skip(f"run_with_cache not working on TransformerBridge: {e}")


-def test_hook_dict_access():
+def test_hook_dict_access(model):
     """Test that hook_dict property works with TransformerBridge."""
     try:
         hook_dict = model.hook_dict
@@ -141,7 +146,7 @@ def test_hook_dict_access():
         pytest.skip(f"hook_dict not working on TransformerBridge: {e}")


-def test_basic_forward_with_hooks():
+def test_basic_forward_with_hooks(model):
     """Test basic forward pass with hooks on TransformerBridge."""

     def simple_hook(tensor, hook):
@@ -168,7 +173,7 @@ def simple_hook(tensor, hook):
         pytest.skip(f"Forward with hooks not working on TransformerBridge: {e}")


-def test_hook_names_consistency():
+def test_hook_names_consistency(model):
     """Test that hook names are consistent and follow expected patterns."""
     try:
         hook_dict = model.hook_dict
@@ -194,7 +199,7 @@ def test_hook_names_consistency():
         pytest.skip(f"Hook names check failed on TransformerBridge: {e}")


-def test_caching_with_names_filter():
+def test_caching_with_names_filter(model):
     """Test that caching with names filter works with TransformerBridge."""
     try:
         hook_dict = model.hook_dict
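
The change above replaces a module-level model load, which ran at import time during test collection, with a module-scoped fixture. The pattern is standard pytest: the fixture body runs once on first use and every test in the module receives the same object. A minimal standalone illustration with a hypothetical resource:

import pytest

@pytest.fixture(scope="module")
def expensive_resource():
    # Created once per module on first use, then shared by every test
    # below; the costly setup is not repeated for each test function.
    return {"loaded": True}

def test_first(expensive_resource):
    assert expensive_resource["loaded"]

def test_second(expensive_resource):
    # Receives the exact same dict instance as test_first.
    assert expensive_resource["loaded"]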

tests/integration/model_bridge/compatibility/test_match_huggingface.py

Lines changed: 16 additions & 22 deletions
@@ -13,13 +13,20 @@ class TestMatchHuggingFace:
     def model_name(self, request):
         return request.param

+    @pytest.fixture(scope="class")
+    def bridge_model(self, model_name):
+        """Load TransformerBridge once per test class."""
+        return TransformerBridge.boot_transformers(model_name, device="cpu")
+
+    @pytest.fixture(scope="class")
+    def hf_model(self, model_name):
+        """Load HuggingFace model once per test class."""
+        return AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
+
     # tests
-    def test_compare_huggingface_mlp_match_local_implementation(self, model_name):
+    def test_compare_huggingface_mlp_match_local_implementation(self, bridge_model, hf_model):
         """Test that TransformerBridge MLP outputs match HuggingFace MLP outputs."""
         try:
-            bridge_model = TransformerBridge.boot_transformers(model_name, device="cpu")
-            hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
-
             tensor_shape = (3, 5, bridge_model.cfg.d_model)
             test_tensor = torch.randn(tensor_shape)
@@ -48,12 +55,9 @@ def test_compare_huggingface_mlp_match_local_implementation(self, model_name):
         except Exception as e:
             pytest.fail(f"Unexpected error in MLP comparison: {e}")

-    def test_compare_huggingface_attention_match_local_implementation(self, model_name):
+    def test_compare_huggingface_attention_match_local_implementation(self, bridge_model, hf_model):
         """Test that TransformerBridge attention outputs match HuggingFace attention outputs."""
         try:
-            bridge_model = TransformerBridge.boot_transformers(model_name, device="cpu")
-            hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
-
             batch, pos, d_model = 3, 5, bridge_model.cfg.d_model
             input_tensor = torch.randn(batch, pos, d_model)
@@ -113,12 +117,9 @@ def test_compare_huggingface_attention_match_local_implementation(self, model_na
         except Exception as e:
             pytest.fail(f"Unexpected error in attention comparison: {e}")

-    def test_full_model_output_match(self, model_name):
+    def test_full_model_output_match(self, bridge_model, hf_model):
         """Test that full TransformerBridge model output matches HuggingFace model output."""
         try:
-            bridge_model = TransformerBridge.boot_transformers(model_name, device="cpu")
-            hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
-
             # Test with a simple prompt
             prompt = "The capital of France is"
@@ -144,11 +145,9 @@
         except Exception as e:
             pytest.fail(f"Unexpected error in full model comparison: {e}")

-    def test_tokenizer_consistency(self, model_name):
+    def test_tokenizer_consistency(self, bridge_model):
         """Test that TransformerBridge tokenizer matches HuggingFace tokenizer."""
         try:
-            bridge_model = TransformerBridge.boot_transformers(model_name, device="cpu")
-
             # Test tokenization
             prompt = "Hello, world! This is a test."
             bridge_tokens = bridge_model.to_tokens(prompt)
@@ -171,12 +170,9 @@ def test_tokenizer_consistency(self, model_name):
         except Exception as e:
             pytest.fail(f"Unexpected error in tokenizer consistency: {e}")

-    def test_config_consistency(self, model_name):
+    def test_config_consistency(self, bridge_model, hf_model):
         """Test that TransformerBridge config matches HuggingFace config."""
         try:
-            bridge_model = TransformerBridge.boot_transformers(model_name, device="cpu")
-            hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")
-
             bridge_cfg = bridge_model.cfg
             hf_cfg = hf_model.config
@@ -199,11 +195,9 @@
         except Exception as e:
             pytest.fail(f"Unexpected error in config consistency: {e}")

-    def test_weight_access_consistency(self, model_name):
+    def test_weight_access_consistency(self, bridge_model):
         """Test that TransformerBridge weight access provides expected values."""
         try:
-            bridge_model = TransformerBridge.boot_transformers(model_name, device="cpu")
-
             # Test basic weight access patterns
             weight_checks = []

tests/integration/model_bridge/generalized_components/test_qkv_bridge_integration.py renamed to tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_bridge_integration.py

Lines changed: 19 additions & 46 deletions
@@ -1,4 +1,4 @@
-"""Lightweight integration tests for QKVBridge.
+"""Lightweight integration tests for JointQKVAttentionBridge.

 Tests the core functionality without loading large models to keep CI fast.
 """
@@ -7,11 +7,10 @@
 import torch

 import transformer_lens.utils as utils
-from transformer_lens.model_bridge.generalized_components.qkv_bridge import QKVBridge


-class TestQKVBridgeIntegration:
-    """Minimal integration tests for QKVBridge."""
+class TestJointQKVAttentionBridgeIntegration:
+    """Minimal integration tests for JointQKVAttentionBridge."""

     def test_hook_alias_resolution(self):
         """Test that hook aliases are properly resolved."""
@@ -25,38 +24,6 @@ def test_hook_alias_resolution(self):
         assert utils.get_act_name("q", 1) == "blocks.1.attn.hook_q"
         assert utils.get_act_name("k", 2) == "blocks.2.attn.hook_k"

-    def test_joint_qkv_attention_bridge_properties(self):
-        """Test that JointQKVAttentionBridge properties are properly resolved."""
-        from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
-            JointQKVAttentionBridge,
-        )
-
-        class TestConfig:
-            n_heads = 12
-
-        qkv_bridge = QKVBridge(name="qkv", config=TestConfig())
-
-        qkv_attention_bridge = JointQKVAttentionBridge(
-            name="blocks.0.attn",
-            config=TestConfig(),
-            submodules={"qkv": qkv_bridge},
-        )
-
-        assert qkv_attention_bridge.q.hook_in == qkv_bridge.q_hook_in
-        assert qkv_attention_bridge.q.hook_out == qkv_bridge.q_hook_out
-        assert qkv_attention_bridge.k.hook_in == qkv_bridge.k_hook_in
-        assert qkv_attention_bridge.k.hook_out == qkv_bridge.k_hook_out
-        assert qkv_attention_bridge.v.hook_in == qkv_bridge.v_hook_in
-        assert qkv_attention_bridge.v.hook_out == qkv_bridge.v_hook_out
-
-    def test_component_class_exists(self):
-        """Test that QKVBridge class can be imported."""
-
-        # Verify the class exists and has expected methods
-        assert hasattr(QKVBridge, "forward")
-        assert hasattr(QKVBridge, "_create_qkv_conversion_rule")
-        assert hasattr(QKVBridge, "_create_qkv_separation_rule")
-
     def test_hook_point_has_hooks_method(self):
         """Test that HookPoint.has_hooks method works correctly."""
         from transformer_lens.hook_points import HookPoint
@@ -82,9 +49,9 @@ def dummy_hook(x, hook):
         assert not hook_point.has_hooks()

     def test_architecture_imports(self):
-        """Test that architecture files can be imported and reference QKVBridge."""
+        """Test that architecture files can be imported and reference JointQKVAttentionBridge."""
         # Test that we can import the architecture files without errors
-        # Test that QKVBridge is referenced in the source files
+        # Test that JointQKVAttentionBridge is referenced in the source files
         import inspect

         from transformer_lens.model_bridge.supported_architectures import (
@@ -94,13 +61,19 @@ def test_architecture_imports(self):
         )

         gpt2_source = inspect.getsource(gpt2)
-        assert "QKVBridge" in gpt2_source, "GPT-2 architecture should reference QKVBridge"
+        assert (
+            "JointQKVAttentionBridge" in gpt2_source
+        ), "GPT-2 architecture should reference JointQKVAttentionBridge"

         bloom_source = inspect.getsource(bloom)
-        assert "QKVBridge" in bloom_source, "BLOOM architecture should reference QKVBridge"
+        assert (
+            "JointQKVAttentionBridge" in bloom_source
+        ), "BLOOM architecture should reference JointQKVAttentionBridge"

         neox_source = inspect.getsource(neox)
-        assert "QKVBridge" in neox_source, "NeoX architecture should reference QKVBridge"
+        assert (
+            "JointQKVAttentionBridge" in neox_source
+        ), "NeoX architecture should reference JointQKVAttentionBridge"

     @pytest.mark.skip(reason="Requires model loading - too slow for CI")
     def test_distilgpt2_integration(self):
@@ -112,15 +85,15 @@ def test_distilgpt2_integration(self):
         torch.set_grad_enabled(False)
         model = TransformerBridge.boot_transformers("distilgpt2", device="cpu")

-        # Verify QKVBridge usage
-        qkv_bridge_modules = [
+        # Verify JointQKVAttentionBridge usage
+        joint_qkv_attention_bridge_modules = [
             name
             for name, module in model.named_modules()
-            if "QKVBridge" in getattr(module, "__class__", {}).get("__name__", "")
+            if "JointQKVAttentionBridge" in getattr(module, "__class__", {}).get("__name__", "")
         ]
         assert (
-            len(qkv_bridge_modules) == 6
-        ), f"Expected 6 QKVBridge modules, got {len(qkv_bridge_modules)}"
+            len(joint_qkv_attention_bridge_modules) == 6
+        ), f"Expected 6 JointQKVAttentionBridge modules, got {len(joint_qkv_attention_bridge_modules)}"

         # Test basic functionality
         tokens = model.to_tokens("Test")
