Skip to content

Commit 975bea7

Browse files
committed
Merge remote-tracking branch 'origin/dev-3.x' into add_support_for_gpt_oss
2 parents 63d7b45 + 1ad8162 commit 975bea7

File tree

6 files changed

+613
-35
lines changed

6 files changed

+613
-35
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Lightweight integration tests for JointQKVAttentionBridge.
2+
3+
Tests the core functionality without loading large models to keep CI fast.
4+
"""
5+
6+
import pytest
7+
import torch
8+
9+
import transformer_lens.utils as utils
10+
11+
12+
class TestJointQKVAttentionBridgeIntegration:
13+
"""Minimal integration tests for JointQKVAttentionBridge."""
14+
15+
def test_hook_alias_resolution(self):
16+
"""Test that hook aliases are properly resolved."""
17+
# Test the hook alias resolution that caused the original issue
18+
hook_name = utils.get_act_name("v", 0)
19+
assert (
20+
hook_name == "blocks.0.attn.hook_v"
21+
), f"Expected 'blocks.0.attn.hook_v', got '{hook_name}'"
22+
23+
# Test other hook names
24+
assert utils.get_act_name("q", 1) == "blocks.1.attn.hook_q"
25+
assert utils.get_act_name("k", 2) == "blocks.2.attn.hook_k"
26+
27+
def test_component_class_exists(self):
28+
"""Test that JointQKVAttentionBridge class can be imported."""
29+
from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
30+
JointQKVAttentionBridge,
31+
)
32+
33+
# Verify the class exists and has expected methods
34+
assert hasattr(JointQKVAttentionBridge, "forward")
35+
assert hasattr(JointQKVAttentionBridge, "_reconstruct_attention")
36+
assert hasattr(JointQKVAttentionBridge, "_manual_attention_computation")
37+
38+
def test_hook_point_has_hooks_method(self):
39+
"""Test that HookPoint.has_hooks method works correctly."""
40+
from transformer_lens.hook_points import HookPoint
41+
42+
hook_point = HookPoint()
43+
44+
# Test initial state
45+
assert not hook_point.has_hooks()
46+
assert not hook_point.has_hooks(dir="fwd")
47+
assert not hook_point.has_hooks(dir="bwd")
48+
49+
# Add a hook and test detection
50+
def dummy_hook(x, hook):
51+
return x
52+
53+
hook_point.add_hook(dummy_hook)
54+
assert hook_point.has_hooks()
55+
assert hook_point.has_hooks(dir="fwd")
56+
assert not hook_point.has_hooks(dir="bwd")
57+
58+
# Clean up
59+
hook_point.remove_hooks()
60+
assert not hook_point.has_hooks()
61+
62+
def test_architecture_imports(self):
63+
"""Test that architecture files can be imported and reference JointQKVAttentionBridge."""
64+
# Test that we can import the architecture files without errors
65+
# Test that JointQKVAttentionBridge is referenced in the source files
66+
import inspect
67+
68+
from transformer_lens.model_bridge.supported_architectures import (
69+
bloom,
70+
gpt2,
71+
neox,
72+
)
73+
74+
gpt2_source = inspect.getsource(gpt2)
75+
assert (
76+
"JointQKVAttentionBridge" in gpt2_source
77+
), "GPT-2 architecture should reference JointQKVAttentionBridge"
78+
79+
bloom_source = inspect.getsource(bloom)
80+
assert (
81+
"JointQKVAttentionBridge" in bloom_source
82+
), "BLOOM architecture should reference JointQKVAttentionBridge"
83+
84+
neox_source = inspect.getsource(neox)
85+
assert (
86+
"JointQKVAttentionBridge" in neox_source
87+
), "NeoX architecture should reference JointQKVAttentionBridge"
88+
89+
@pytest.mark.skip(reason="Requires model loading - too slow for CI")
90+
def test_distilgpt2_integration(self):
91+
"""Full integration test with DistilGPT-2 (skipped in CI)."""
92+
# This test would load DistilGPT-2 and test full functionality
93+
# but is skipped by default to keep CI fast
94+
from transformer_lens.model_bridge import TransformerBridge
95+
96+
torch.set_grad_enabled(False)
97+
model = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
98+
99+
# Verify JointQKVAttentionBridge usage
100+
joint_qkv_modules = [
101+
name
102+
for name, module in model.named_modules()
103+
if "JointQKVAttentionBridge" in getattr(module, "__class__", {}).get("__name__", "")
104+
]
105+
assert (
106+
len(joint_qkv_modules) == 6
107+
), f"Expected 6 JointQKVAttentionBridge modules, got {len(joint_qkv_modules)}"
108+
109+
# Test basic functionality
110+
tokens = model.to_tokens("Test")
111+
with torch.no_grad():
112+
loss = model(tokens, return_type="loss")
113+
assert torch.isfinite(loss) and loss > 0
114+
115+
# Test hook integration
116+
def v_ablation_hook(value, hook):
117+
value[:, :, 0, :] = 0.0 # Ablate first head
118+
return value
119+
120+
original_loss = model(tokens, return_type="loss")
121+
hooked_loss = model.run_with_hooks(
122+
tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("v", 0), v_ablation_hook)]
123+
)
124+
assert not torch.isclose(original_loss, hooked_loss, atol=1e-6)

tests/integration/model_bridge/test_bridge_root_module_cache_compatibility.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
},
1212
)
1313

14-
# Attention output enabled via hf_config_overrides
15-
1614
act_names_in_cache = [
1715
# "hook_embed",
1816
# "hook_pos_embed",

tests/unit/test_hook_points.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,264 @@ def test_hook(activation, hook):
208208
# Since hook returns None, the original input should be returned
209209
# (HookPoint's forward method returns the input when no valid hook result)
210210
assert torch.equal(result, test_input)
211+
212+
213+
class TestHookPointHasHooks:
214+
"""Comprehensive test suite for HookPoint.has_hooks method."""
215+
216+
def setup_method(self):
217+
"""Set up fresh HookPoint and sample hook for each test."""
218+
self.hook_point = HookPoint()
219+
220+
def sample_hook(activation, hook):
221+
return activation
222+
223+
self.sample_hook = sample_hook
224+
225+
def test_no_hooks_returns_false(self):
226+
"""Test that has_hooks returns False when no hooks are present."""
227+
assert not self.hook_point.has_hooks()
228+
assert not self.hook_point.has_hooks(dir="fwd")
229+
assert not self.hook_point.has_hooks(dir="bwd")
230+
assert not self.hook_point.has_hooks(dir="both")
231+
232+
def test_forward_hook_detection(self):
233+
"""Test detection of forward hooks."""
234+
# Add a forward hook
235+
self.hook_point.add_hook(self.sample_hook, dir="fwd")
236+
237+
# Should detect forward hooks
238+
assert self.hook_point.has_hooks()
239+
assert self.hook_point.has_hooks(dir="fwd")
240+
assert self.hook_point.has_hooks(dir="both")
241+
242+
# Should not detect backward hooks
243+
assert not self.hook_point.has_hooks(dir="bwd")
244+
245+
def test_backward_hook_detection(self):
246+
"""Test detection of backward hooks."""
247+
# Add a backward hook
248+
self.hook_point.add_hook(self.sample_hook, dir="bwd")
249+
250+
# Should detect backward hooks
251+
assert self.hook_point.has_hooks()
252+
assert self.hook_point.has_hooks(dir="bwd")
253+
assert self.hook_point.has_hooks(dir="both")
254+
255+
# Should not detect forward hooks
256+
assert not self.hook_point.has_hooks(dir="fwd")
257+
258+
def test_both_direction_hooks(self):
259+
"""Test detection when both forward and backward hooks are present."""
260+
# Add both forward and backward hooks
261+
self.hook_point.add_hook(self.sample_hook, dir="fwd")
262+
self.hook_point.add_hook(self.sample_hook, dir="bwd")
263+
264+
# All directions should detect hooks
265+
assert self.hook_point.has_hooks()
266+
assert self.hook_point.has_hooks(dir="fwd")
267+
assert self.hook_point.has_hooks(dir="bwd")
268+
assert self.hook_point.has_hooks(dir="both")
269+
270+
def test_permanent_hook_detection(self):
271+
"""Test detection of permanent hooks."""
272+
# Add a permanent forward hook
273+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=True)
274+
275+
# Should detect permanent hooks by default
276+
assert self.hook_point.has_hooks()
277+
assert self.hook_point.has_hooks(including_permanent=True)
278+
279+
# Should not detect when excluding permanent hooks
280+
assert not self.hook_point.has_hooks(including_permanent=False)
281+
282+
def test_non_permanent_hook_detection(self):
283+
"""Test detection of non-permanent hooks."""
284+
# Add a non-permanent forward hook
285+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=False)
286+
287+
# Should detect non-permanent hooks regardless of including_permanent setting
288+
assert self.hook_point.has_hooks()
289+
assert self.hook_point.has_hooks(including_permanent=True)
290+
assert self.hook_point.has_hooks(including_permanent=False)
291+
292+
def test_mixed_permanent_hooks(self):
293+
"""Test detection with mix of permanent and non-permanent hooks."""
294+
# Add both permanent and non-permanent hooks
295+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=True)
296+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=False)
297+
298+
# Should detect hooks in both cases
299+
assert self.hook_point.has_hooks(including_permanent=True)
300+
assert self.hook_point.has_hooks(including_permanent=False)
301+
302+
def test_only_permanent_hooks(self):
303+
"""Test detection when only permanent hooks are present."""
304+
# Add only permanent hooks
305+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=True)
306+
self.hook_point.add_hook(self.sample_hook, dir="bwd", is_permanent=True)
307+
308+
# Should detect when including permanent
309+
assert self.hook_point.has_hooks(including_permanent=True)
310+
assert self.hook_point.has_hooks(dir="fwd", including_permanent=True)
311+
assert self.hook_point.has_hooks(dir="bwd", including_permanent=True)
312+
313+
# Should not detect when excluding permanent
314+
assert not self.hook_point.has_hooks(including_permanent=False)
315+
assert not self.hook_point.has_hooks(dir="fwd", including_permanent=False)
316+
assert not self.hook_point.has_hooks(dir="bwd", including_permanent=False)
317+
318+
def test_context_level_filtering(self):
319+
"""Test context level filtering functionality."""
320+
# Add hooks at different context levels
321+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0)
322+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=1)
323+
self.hook_point.add_hook(self.sample_hook, dir="bwd", level=2)
324+
325+
# Should detect hooks at specific levels
326+
assert self.hook_point.has_hooks(level=0)
327+
assert self.hook_point.has_hooks(level=1)
328+
assert self.hook_point.has_hooks(level=2)
329+
330+
# Should not detect hooks at non-existent levels
331+
assert not self.hook_point.has_hooks(level=3)
332+
assert not self.hook_point.has_hooks(level=-1)
333+
334+
# Should detect all hooks when level is None
335+
assert self.hook_point.has_hooks(level=None)
336+
337+
def test_context_level_with_direction(self):
338+
"""Test context level filtering combined with direction filtering."""
339+
# Add hooks at different levels and directions
340+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0)
341+
self.hook_point.add_hook(self.sample_hook, dir="bwd", level=1)
342+
343+
# Test specific combinations
344+
assert self.hook_point.has_hooks(dir="fwd", level=0)
345+
assert self.hook_point.has_hooks(dir="bwd", level=1)
346+
347+
# Test non-matching combinations
348+
assert not self.hook_point.has_hooks(dir="fwd", level=1)
349+
assert not self.hook_point.has_hooks(dir="bwd", level=0)
350+
351+
def test_context_level_with_permanent_flags(self):
352+
"""Test context level filtering combined with permanent hook filtering."""
353+
# Add permanent and non-permanent hooks at different levels
354+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0, is_permanent=True)
355+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=1, is_permanent=False)
356+
357+
# Test combinations
358+
assert self.hook_point.has_hooks(level=0, including_permanent=True)
359+
assert not self.hook_point.has_hooks(level=0, including_permanent=False)
360+
assert self.hook_point.has_hooks(level=1, including_permanent=True)
361+
assert self.hook_point.has_hooks(level=1, including_permanent=False)
362+
363+
def test_all_parameters_combined(self):
364+
"""Test all parameters combined in various ways."""
365+
# Create a complex setup with multiple hooks
366+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0, is_permanent=True)
367+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=1, is_permanent=False)
368+
self.hook_point.add_hook(self.sample_hook, dir="bwd", level=0, is_permanent=False)
369+
self.hook_point.add_hook(self.sample_hook, dir="bwd", level=2, is_permanent=True)
370+
371+
# Test specific combinations
372+
assert self.hook_point.has_hooks(dir="fwd", level=0, including_permanent=True)
373+
assert not self.hook_point.has_hooks(dir="fwd", level=0, including_permanent=False)
374+
assert self.hook_point.has_hooks(dir="fwd", level=1, including_permanent=False)
375+
assert self.hook_point.has_hooks(dir="bwd", level=0, including_permanent=False)
376+
assert not self.hook_point.has_hooks(dir="bwd", level=1, including_permanent=False)
377+
assert self.hook_point.has_hooks(dir="bwd", level=2, including_permanent=True)
378+
379+
def test_invalid_direction_raises_error(self):
380+
"""Test that invalid direction parameter raises error (caught by type checking)."""
381+
# Note: beartype catches this at the parameter level before reaching the ValueError
382+
import pytest
383+
from beartype.roar import BeartypeCallHintParamViolation
384+
385+
with pytest.raises(BeartypeCallHintParamViolation):
386+
self.hook_point.has_hooks(dir="invalid") # type: ignore
387+
388+
def test_multiple_hooks_same_criteria(self):
389+
"""Test detection when multiple hooks match the same criteria."""
390+
# Add multiple hooks with same criteria
391+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0, is_permanent=False)
392+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0, is_permanent=False)
393+
self.hook_point.add_hook(self.sample_hook, dir="fwd", level=0, is_permanent=False)
394+
395+
# Should still detect hooks (method returns True on first match)
396+
assert self.hook_point.has_hooks(dir="fwd", level=0, including_permanent=False)
397+
398+
def test_hook_removal_affects_detection(self):
399+
"""Test that removing hooks affects detection."""
400+
# Add a hook
401+
self.hook_point.add_hook(self.sample_hook, dir="fwd")
402+
assert self.hook_point.has_hooks()
403+
404+
# Remove all hooks
405+
self.hook_point.remove_hooks(dir="both")
406+
assert not self.hook_point.has_hooks()
407+
408+
def test_default_parameter_values(self):
409+
"""Test that default parameter values work correctly."""
410+
# Add hooks to test defaults
411+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=True, level=0)
412+
self.hook_point.add_hook(self.sample_hook, dir="bwd", is_permanent=False, level=1)
413+
414+
# Test default behavior (dir="both", including_permanent=True, level=None)
415+
assert self.hook_point.has_hooks()
416+
417+
# This should be equivalent to:
418+
assert self.hook_point.has_hooks(dir="both", including_permanent=True, level=None)
419+
420+
def test_edge_case_empty_after_filtering(self):
421+
"""Test edge case where hooks exist but are filtered out."""
422+
# Add hooks that will be filtered out
423+
self.hook_point.add_hook(self.sample_hook, dir="fwd", is_permanent=True, level=5)
424+
425+
# These should not detect the hook due to filtering
426+
assert not self.hook_point.has_hooks(including_permanent=False)
427+
assert not self.hook_point.has_hooks(dir="bwd")
428+
assert not self.hook_point.has_hooks(level=0)
429+
assert not self.hook_point.has_hooks(dir="bwd", level=5, including_permanent=True)
430+
431+
def test_functional_hook_execution_still_works(self):
432+
"""Test that has_hooks doesn't interfere with actual hook functionality."""
433+
import torch
434+
435+
results = []
436+
437+
def test_hook(activation, hook):
438+
results.append("hook_called")
439+
return activation
440+
441+
# Add hook and verify detection
442+
self.hook_point.add_hook(test_hook, dir="fwd")
443+
assert self.hook_point.has_hooks()
444+
445+
# Execute hook and verify it still works
446+
test_input = torch.tensor([1.0, 2.0, 3.0])
447+
output = self.hook_point(test_input)
448+
449+
assert torch.equal(output, test_input)
450+
assert "hook_called" in results
451+
452+
def test_hook_point_with_conversions(self):
453+
"""Test has_hooks with hook conversions if they exist."""
454+
import torch
455+
456+
# This test ensures has_hooks works even when hook conversions are involved
457+
def simple_hook(activation, hook):
458+
return activation * 2
459+
460+
# Add hook
461+
self.hook_point.add_hook(simple_hook, dir="fwd")
462+
463+
# Should detect hook regardless of any internal conversions
464+
assert self.hook_point.has_hooks()
465+
assert self.hook_point.has_hooks(dir="fwd")
466+
467+
# Test actual functionality still works
468+
test_input = torch.tensor([1.0, 2.0])
469+
output = self.hook_point(test_input)
470+
expected = torch.tensor([2.0, 4.0])
471+
assert torch.allclose(output, expected)

0 commit comments

Comments
 (0)