
Commit 5babde2

cleaned up a lot of things (#1113)
* cleaned up a lot of things
* removed extra function
* fixed typing
* fixed index bug
* removed extra stuff
* fixed main demo
* removed bad chunk
* removed attention check
* fixed cache
* fixed type check
* fixed demo issue
* fixed test
* fixed typing
* updated type
* restored patched function
* fixed gemma 3
* fixed test
* fixed typing
* continued working through gemma compat
* fixed more issues
* got closer
* ran format
* fixed extra config
* grouped results by phase
* cleaned up adapter
* fixed weight processing issue
* revised hooks
* fixed phase 2
* set flags correctly
* revised benchmarks for granularity
* improved gemma compatibility
* cleaned up memory
* revised architecture adapters
* cleaned up memory
* improved some models
* used correct component
* verified more architectures
* fixed more models
* ran format
* fixed typing
* fixed typing
* fixed test
* fixed t5
* fixed format
* fixed test
1 parent 69db6a3 commit 5babde2

35 files changed (+2039, -502 lines)

tests/integration/model_bridge/generalized_components/test_joint_qkv_attention_bridge_integration.py

Lines changed: 9 additions & 6 deletions
@@ -49,9 +49,9 @@ def dummy_hook(x, hook):
         assert not hook_point.has_hooks()
 
     def test_architecture_imports(self):
-        """Test that architecture files can be imported and reference JointQKVAttentionBridge."""
+        """Test that architecture files can be imported and use appropriate attention bridges."""
         # Test that we can import the architecture files without errors
-        # Test that JointQKVAttentionBridge is referenced in the source files
+        # Test that appropriate attention bridges are referenced in the source files
         import inspect
 
         from transformer_lens.model_bridge.supported_architectures import (
@@ -65,15 +65,18 @@ def test_architecture_imports(self):
             "JointQKVAttentionBridge" in gpt2_source
         ), "GPT-2 architecture should reference JointQKVAttentionBridge"
 
+        # BLOOM uses BloomAttentionBridge instead of JointQKVAttentionBridge
+        # because it requires alibi bias and residual connections
         bloom_source = inspect.getsource(bloom)
         assert (
-            "JointQKVAttentionBridge" in bloom_source
-        ), "BLOOM architecture should reference JointQKVAttentionBridge"
+            "BloomAttentionBridge" in bloom_source
+        ), "BLOOM architecture should reference BloomAttentionBridge"
 
+        # NeoX uses JointQKVPositionEmbeddingsAttentionBridge for rotary embeddings
        neox_source = inspect.getsource(neox)
         assert (
-            "JointQKVAttentionBridge" in neox_source
-        ), "NeoX architecture should reference JointQKVAttentionBridge"
+            "JointQKVPositionEmbeddingsAttentionBridge" in neox_source
+        ), "NeoX architecture should reference JointQKVPositionEmbeddingsAttentionBridge"
 
     @pytest.mark.slow
     def test_distilgpt2_integration(self):
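
The per-architecture expectations above can be read as a small table-driven check. A minimal sketch, assuming the architecture modules are importable as gpt2, bloom, and neox (bloom and neox appear in the diff; the gpt2 module name is inferred from gpt2_source):

import inspect

from transformer_lens.model_bridge.supported_architectures import bloom, gpt2, neox

# Expected attention bridge per architecture, taken from the assertions above
EXPECTED_BRIDGE = {
    gpt2: "JointQKVAttentionBridge",                    # fused QKV, learned positions
    bloom: "BloomAttentionBridge",                      # needs alibi bias and residual
    neox: "JointQKVPositionEmbeddingsAttentionBridge",  # rotary position embeddings
}

for module, bridge_name in EXPECTED_BRIDGE.items():
    source = inspect.getsource(module)
    assert bridge_name in source, f"{module.__name__} should reference {bridge_name}"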

tests/integration/model_bridge/test_weight_processing_integration.py

Lines changed: 5 additions & 8 deletions
@@ -88,9 +88,7 @@ def head_ablation_hook(value: Float[torch.Tensor, "batch pos head_index d_head"]
     # ===========================================
     print("\n3. Loading TransformerBridge without processing...")
     try:
-        bridge_unprocessed = TransformerBridge.boot_transformers(
-            model_name, device=device, apply_weight_processing=False
-        )
+        bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device)
 
         print("\n Testing baseline performance...")
         bridge_unprocessed_baseline = bridge_unprocessed(tokens, return_type="loss")
@@ -122,9 +120,9 @@ def head_ablation_hook(value: Float[torch.Tensor, "batch pos head_index d_head"]
     # ===========================================
     print("\n4. Loading TransformerBridge with processing...")
     try:
-        bridge_processed = TransformerBridge.boot_transformers(
-            model_name, device=device, apply_weight_processing=True
-        )
+        bridge_processed = TransformerBridge.boot_transformers(model_name, device=device)
+
+        bridge_processed.process_weights()
 
         print("\n Testing baseline performance...")
         bridge_processed_baseline = bridge_processed(tokens, return_type="loss")
@@ -288,10 +286,9 @@ def head_ablation_hook(value: Float[torch.Tensor, "batch pos head_index d_head"]
     if overall_success:
         print("\n🎉🎉🎉 FULL INTEGRATION COMPATIBILITY ACHIEVED! 🎉🎉🎉")
         print("TransformerBridge is fully compatible with HookedTransformer!")
-        return True
     else:
         print("\n⚠️ Integration compatibility issues detected")
-        return False
+        pytest.fail("Integration compatibility issues detected")
 
 
 @pytest.mark.skip(
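
The apply_weight_processing flag disappears from these tests: the bridge is always booted with raw weights, and processing becomes an explicit second step. A minimal sketch of the two paths, with the model name and device as placeholders:

from transformer_lens.model_bridge import TransformerBridge

model_name, device = "gpt2", "cpu"  # placeholders

# Unprocessed: raw HuggingFace weights, nothing folded at load time
bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device)

# Processed: same load path, then opt in to weight processing explicitly
bridge_processed = TransformerBridge.boot_transformers(model_name, device=device)
bridge_processed.process_weights()

loss = bridge_processed("Hello world", return_type="loss")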

tests/integration/model_bridge/test_weight_processing_perfect_match.py

Lines changed: 1 addition & 2 deletions
@@ -127,10 +127,9 @@ def head_ablation_hook(value, hook):
         print("\n🎉🎉🎉 PERFECT MATCH ACHIEVED! 🎉🎉🎉")
         print("The corrected processing matches HookedTransformer exactly!")
         print("This solution can be applied to TransformerBridge for perfect ablation matching.")
-        return True
     else:
         print("\n⚠️ Not quite perfect yet, but very close!")
-        return False
+        pytest.fail("Not quite perfect yet, but very close!")
 
 
 if __name__ == "__main__":
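
Both test files swap return True / return False for pytest.fail(...), since pytest ignores a test's return value (newer versions also warn about non-None returns): a test only fails by raising. A minimal illustration, with a hypothetical outputs_match() standing in for the real bridge-vs-HookedTransformer comparison:

import pytest


def outputs_match() -> bool:
    # Hypothetical stand-in for the real ablation comparison
    return False


def test_silently_passes():
    return outputs_match()  # return value is discarded, so this "passes" even on mismatch


def test_fails_properly():
    if not outputs_match():
        pytest.fail("Outputs do not match")  # raises, so pytest reports a real failure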

transformer_lens/benchmarks/component_outputs.py

Lines changed: 106 additions & 16 deletions
@@ -265,31 +265,44 @@ def benchmark_all_components(
 
         results: List[ComponentTestResult] = []
 
+        # Block-type components that need to be tested recursively by layer
+        # (they are ModuleLists that don't have direct forward methods)
+        block_components = {"blocks", "encoder_blocks", "decoder_blocks"}
+
         # Test top-level components (embed, pos_embed, ln_final, unembed)
         for comp_name, component in component_mapping.items():
             if comp_name in skip_components:
                 continue
 
-            if comp_name == "blocks":
-                # Handle blocks separately
+            if comp_name in block_components:
+                # Handle blocks separately - test their subcomponents by layer
                 continue
 
             result = self._test_component(comp_name, component, test_inputs)
             if result is not None:
                 results.append(result)
 
         # Test block components recursively
-        if "blocks" in component_mapping and "blocks" not in skip_components:
-            blocks_component = component_mapping["blocks"]
-            n_layers = self.cfg.n_layers
-
-            for layer_idx in range(n_layers):
-                # Recursively test each subcomponent and its nested subcomponents
-                for subcomp_name, subcomponent in blocks_component.submodules.items():
-                    comp_path = f"blocks.{layer_idx}.{subcomp_name}"
-                    self._test_component_recursive(
-                        comp_path, subcomponent, test_inputs, results, skip_components
-                    )
+        for block_type in block_components:
+            if block_type in component_mapping and block_type not in skip_components:
+                blocks_component = component_mapping[block_type]
+                n_layers = self.cfg.n_layers
+
+                for layer_idx in range(n_layers):
+                    # Recursively test each subcomponent and its nested subcomponents
+                    for subcomp_name, subcomponent in blocks_component.submodules.items():
+                        comp_path = f"{block_type}.{layer_idx}.{subcomp_name}"
+                        self._test_component_recursive(
+                            comp_path, subcomponent, test_inputs, results, skip_components
+                        )
+
+        # Clean up test inputs to free memory
+        if test_inputs is not None:
+            for key in list(test_inputs.keys()):
+                tensor = test_inputs[key]
+                if tensor is not None and isinstance(tensor, torch.Tensor):
+                    del tensor
+            test_inputs.clear()
 
         # Create report
         passed = sum(1 for r in results if r.passed)
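
The block loop now covers decoder-only and encoder-decoder layouts with the same path construction. A small standalone sketch of the component paths it produces, using hypothetical submodule names ("attn", "mlp") and a two-layer model:

block_components = {"blocks", "encoder_blocks", "decoder_blocks"}

# Hypothetical mapping: a decoder-only model would expose "blocks",
# an encoder-decoder model exposes "encoder_blocks" and "decoder_blocks"
component_mapping = {"embed": None, "encoder_blocks": None, "decoder_blocks": None}
submodules = ["attn", "mlp"]  # hypothetical per-layer subcomponents
n_layers = 2

for block_type in sorted(block_components):
    if block_type in component_mapping:
        for layer_idx in range(n_layers):
            for subcomp_name in submodules:
                print(f"{block_type}.{layer_idx}.{subcomp_name}")
# prints paths like decoder_blocks.0.attn, decoder_blocks.1.mlp, encoder_blocks.0.attn, ...
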
@@ -333,6 +346,58 @@ def _test_component_recursive(
         if component_path in skip_components:
             return
 
+        # Skip MLP components that don't exist as separate modules in HF (name=None)
+        # These are virtual components where fc1/fc2 are directly on the layer
+        # Component testing doesn't work for these because get_component returns the parent layer
+        if "mlp" in component_path and hasattr(component, "name") and component.name is None:
+            return
+
+        # Skip MLP components with custom forward signatures (e.g., BLOOM requires residual)
+        # These can't be tested in isolation without full model context
+        if "mlp" in component_path and hasattr(component, "hf_component"):
+            import inspect
+
+            try:
+                sig = inspect.signature(component.hf_component.forward)
+                params = list(sig.parameters.keys())
+                # Standard MLP only needs hidden_states (or self + hidden_states)
+                # If there are more required params, skip testing
+                if len(params) > 2:  # self + hidden_states + other required params
+                    return
+            except Exception:
+                # If we can't inspect, proceed with testing
+                pass
+
+        # Skip attention components that require position embeddings in Phase 3
+        # These can't be tested in isolation without full model context for position embeddings
+        if (
+            "attn" in component_path
+            and hasattr(component, "requires_position_embeddings")
+            and component.requires_position_embeddings
+        ):
+            return
+
+        # Skip attention components that use native HF attention (maintain_native_attention=True)
+        # These have custom forward signatures (e.g., BLOOM requires residual, alibi, attention_mask)
+        # and can't be tested in isolation without full model context
+        if (
+            "attn" in component_path
+            and hasattr(component, "maintain_native_attention")
+            and component.maintain_native_attention
+        ):
+            return
+
+        # Skip BLOOM and T5 attention and MLP components - they have custom signatures that require
+        # residual connections, alibi bias, or cache_position from the full model context
+        if "attn" in component_path or "mlp" in component_path:
+            # Check if this is a BLOOM or T5 model by looking at the HF model config
+            hf_model_config = getattr(self.hf_model, "config", None)
+            if hf_model_config and hasattr(hf_model_config, "model_type"):
+                # BLOOM requires residual and alibi bias
+                # T5 requires cache_position for relative position embeddings
+                if hf_model_config.model_type in ["bloom", "t5"]:
+                    return
+
         # Skip components that require specific shaped inputs from their parent modules
         # These components expect intermediate outputs from their parent attention/MLP
         # modules and can't be tested with generic hidden state inputs
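
The signature probe in that hunk can be reproduced in isolation. A minimal standalone sketch of the idea with toy modules standing in for the HF components; here the threshold is > 1 because inspect.signature on a bound method already excludes self:

import inspect

import torch
import torch.nn as nn


class StandardMLP(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states


class BloomStyleMLP(nn.Module):
    # BLOOM-style MLPs also need the residual stream, so they cannot be driven
    # with a single generic hidden-states tensor
    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        return hidden_states + residual


def needs_extra_inputs(module: nn.Module) -> bool:
    # Bound methods exclude `self`, so anything beyond one parameter is "extra"
    params = inspect.signature(module.forward).parameters
    return len(params) > 1


print(needs_extra_inputs(StandardMLP()))    # False
print(needs_extra_inputs(BloomStyleMLP()))  # True
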
@@ -402,7 +467,10 @@ def _test_component(
         """
         try:
             # Get bridge component
-            bridge_component = self.adapter.get_component(self.bridge_model, component_path)
+            # The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent
+            bridge_component = cast(
+                GeneralizedComponent, self.adapter.get_component(self.bridge_model, component_path)
+            )
 
             # Get HuggingFace component
             hf_component = self.adapter.get_component(self.hf_model, component_path)
@@ -412,7 +480,14 @@ def _test_component(
             if test_input is None:
                 return None
 
-            # For embedding components, generate token indices once to use for both
+            # Get input args/kwargs from the Bridge component
+            # All bridge components inherit from GeneralizedComponent and have get_dummy_inputs()
+            batch, seq_len, _ = test_input.shape
+            pos_indices = (
+                torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
+            )
+
+            # For embedding components, generate token indices once
             shared_token_indices = None
             if component_path == "embed":
                 batch, seq_len, _ = test_input.shape
@@ -490,13 +565,28 @@ def _test_component(
                 bridge_tensor, hf_tensor
             )
 
+            # Get output shape before deleting tensors
+            output_shape = tuple(bridge_tensor.shape)
+
+            # Clean up output tensors immediately to free memory
+            del bridge_output, hf_output, bridge_tensor, hf_tensor
+            if shared_inputs is not None:
+                # Clean up shared inputs
+                for key in list(shared_inputs.keys()):
+                    val = shared_inputs[key]
+                    if val is not None and isinstance(val, torch.Tensor):
+                        del val
+                        shared_inputs[key] = None
+            if shared_token_indices is not None:
+                del shared_token_indices
+
             return ComponentTestResult(
                 component_path=component_path,
                 component_type=type(component).__name__,
                 passed=passed,
                 max_diff=max_diff,
                 mean_diff=mean_diff,
-                output_shape=tuple(bridge_tensor.shape),
+                output_shape=output_shape,
                 percentile_diffs=percentile_diffs,
             )
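
The pos_indices construction added above is the usual pattern for building per-batch position ids. A quick standalone check of the shapes involved, with hypothetical batch/sequence sizes:

import torch

batch, seq_len = 2, 5
test_input = torch.zeros(batch, seq_len, 16)  # hypothetical [batch, pos, d_model] activations

pos_indices = (
    torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
)

print(pos_indices.shape)  # torch.Size([2, 5])
print(pos_indices[0])     # tensor([0, 1, 2, 3, 4]) - same positions for every batch element
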
transformer_lens/benchmarks/forward_pass.py

Lines changed: 44 additions & 2 deletions
@@ -14,6 +14,29 @@
 from transformer_lens.model_bridge import TransformerBridge
 
 
+def _is_encoder_decoder(model: torch.nn.Module) -> bool:
+    """Check if a model is an encoder-decoder architecture."""
+    config = getattr(model, "config", None)
+    if config is None:
+        return False
+    return getattr(config, "is_encoder_decoder", False)
+
+
+def _get_decoder_input_ids(model: torch.nn.Module, batch_size: int = 1) -> torch.Tensor:
+    """Get decoder_input_ids for encoder-decoder models.
+
+    Args:
+        model: The model to get decoder_start_token_id from
+        batch_size: Batch size for the decoder_input_ids
+
+    Returns:
+        Tensor of shape [batch_size, 1] with decoder_start_token_id
+    """
+    config = getattr(model, "config", None)
+    decoder_start_token_id = getattr(config, "decoder_start_token_id", 0) if config else 0
+    return torch.tensor([[decoder_start_token_id]] * batch_size)
+
+
 def benchmark_forward_pass(
     bridge: TransformerBridge,
     test_text: str,
@@ -34,8 +57,20 @@ def benchmark_forward_pass(
         BenchmarkResult with comparison details
     """
     try:
+        # Check if this is an encoder-decoder model
+        is_enc_dec = _is_encoder_decoder(bridge.original_model)
+
+        # Prepare extra kwargs for encoder-decoder models
+        extra_kwargs = {}
+        if is_enc_dec:
+            tokens = bridge.to_tokens(test_text)
+            batch_size = tokens.shape[0]
+            decoder_input_ids = _get_decoder_input_ids(bridge.original_model, batch_size)
+            decoder_input_ids = decoder_input_ids.to(tokens.device)
+            extra_kwargs["decoder_input_ids"] = decoder_input_ids
+
         # Run bridge forward pass
-        bridge_output = bridge(test_text, return_type="logits")
+        bridge_output = bridge(test_text, return_type="logits", **extra_kwargs)
 
         if reference_model is None:
             # No reference model - just verify output shape and validity
@@ -69,7 +104,14 @@ def benchmark_forward_pass(
         # HuggingFace model
         tokens = bridge.to_tokens(test_text)
         with torch.no_grad():
-            hf_output = reference_model(tokens)
+            if is_enc_dec:
+                # Encoder-decoder models need decoder_input_ids
+                batch_size = tokens.shape[0]
+                decoder_input_ids = _get_decoder_input_ids(reference_model, batch_size)
+                decoder_input_ids = decoder_input_ids.to(tokens.device)
+                hf_output = reference_model(tokens, decoder_input_ids=decoder_input_ids)
+            else:
+                hf_output = reference_model(tokens)
         reference_output = hf_output.logits
 
         return compare_tensors(
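
With these helpers, the same benchmark drives decoder-only and encoder-decoder models; the only difference is an extra decoder_input_ids tensor of shape [batch_size, 1] filled with decoder_start_token_id. A minimal sketch of the reference-model side in plain HuggingFace terms, using "t5-small" purely as a placeholder encoder-decoder model:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

tokens = tokenizer("Hello world", return_tensors="pt").input_ids

if getattr(model.config, "is_encoder_decoder", False):
    # Seed the decoder with a single start token per batch element
    start_id = getattr(model.config, "decoder_start_token_id", 0)
    decoder_input_ids = torch.tensor([[start_id]] * tokens.shape[0])
    with torch.no_grad():
        output = model(tokens, decoder_input_ids=decoder_input_ids)
else:
    with torch.no_grad():
        output = model(tokens)

print(output.logits.shape)  # [batch, decoder positions, vocab]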
