
Commit 0a58e98

updated get params to fill zeroes when needed (#1049)
* updated get params to fill zeroes when needed
* fixed mlp typings
* optimized checks a bit
* fixed tests
* ran garbage collect to improve performance
1 parent 83d8a7a commit 0a58e98
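
The headline change: get_params now fills in zero tensors for components a model does not have, rather than raising (exercised by test_get_params_missing_components below). A minimal sketch of that fallback, assuming a hypothetical helper name and the same config fields the tests use (d_vocab, d_model, n_ctx); the bridge's actual implementation may differ:

import torch

def weight_or_zeros(component, attr, shape, dtype=torch.float32):
    # Hypothetical helper: return the component's weight when present,
    # otherwise a zero tensor of the expected shape.
    if component is None or not hasattr(component, attr):
        return torch.zeros(shape, dtype=dtype)
    return getattr(component, attr)

For example, a missing embedding yields zeros of shape (d_vocab, d_model), which is exactly what the new tests assert.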

File tree

2 files changed: +570 −42 lines changed


tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 327 additions & 0 deletions
@@ -4,6 +4,7 @@
 including model initialization, text generation, hooks, and caching.
 """

+import gc
 import logging

 import pytest
@@ -343,5 +344,331 @@ def capture_pattern_hook(tensor, hook):
     bridge.blocks[0].attn.hook_pattern.remove_hooks()


+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "gpt2",  # GPT-2 architecture
+        "distilgpt2",  # DistilGPT-2 architecture (smaller GPT-2)
+        "EleutherAI/pythia-70m",  # Pythia architecture (smallest, ~70M params)
+        "EleutherAI/gpt-neo-125M",  # GPT-Neo architecture
+        "google/gemma-2-2b-it",  # Gemma architecture (Grouped Query Attention)
+    ],
+)
+def test_get_params(model_name):
+    """Test that get_params works correctly with different model architectures.
+
+    This test verifies that the get_params function can successfully extract
+    parameters from various model types (GPT-2, DistilGPT-2, Pythia, GPT-Neo, Gemma)
+    without encountering attribute errors or missing component issues. This includes
+    models with different attention architectures like Grouped Query Attention (GQA).
+    Covers a range of model sizes from 70M to 2B parameters.
+
+    Args:
+        model_name: The model name to test (parameterized)
+    """
+    # Clear any existing cache/memory before loading models
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    # This should not raise any exceptions
+    try:
+        params_dict = bridge.get_params()
+    except Exception as e:
+        pytest.fail(f"get_params failed for {model_name}: {e}")
+
+    # Verify that we got a dictionary with expected keys
+    assert isinstance(params_dict, dict), "get_params should return a dictionary"
+    assert len(params_dict) > 0, "Parameters dictionary should not be empty"
+
+    # Check for expected embedding parameters
+    assert "embed.W_E" in params_dict, "Should contain embedding weights"
+    assert "pos_embed.W_pos" in params_dict, "Should contain positional embedding weights"
+
+    # Check for expected layer parameters (at least layer 0)
+    assert "blocks.0.attn.W_Q" in params_dict, "Should contain query weights for layer 0"
+    assert "blocks.0.attn.W_K" in params_dict, "Should contain key weights for layer 0"
+    assert "blocks.0.attn.W_V" in params_dict, "Should contain value weights for layer 0"
+    assert "blocks.0.attn.W_O" in params_dict, "Should contain output weights for layer 0"
+
+    # Check for attention biases
+    assert "blocks.0.attn.b_Q" in params_dict, "Should contain query biases for layer 0"
+    assert "blocks.0.attn.b_K" in params_dict, "Should contain key biases for layer 0"
+    assert "blocks.0.attn.b_V" in params_dict, "Should contain value biases for layer 0"
+    assert "blocks.0.attn.b_O" in params_dict, "Should contain output biases for layer 0"
+
+    # Check for MLP parameters
+    assert "blocks.0.mlp.W_in" in params_dict, "Should contain MLP input weights for layer 0"
+    assert "blocks.0.mlp.W_out" in params_dict, "Should contain MLP output weights for layer 0"
+    assert "blocks.0.mlp.b_in" in params_dict, "Should contain MLP input biases for layer 0"
+    assert "blocks.0.mlp.b_out" in params_dict, "Should contain MLP output biases for layer 0"
+
+    # Check for unembedding weights
+    assert "unembed.W_U" in params_dict, "Should contain unembedding weights"
+
+    # Verify that all parameter values are tensors
+    for key, value in params_dict.items():
+        assert isinstance(
+            value, torch.Tensor
+        ), f"Parameter {key} should be a tensor, got {type(value)}"
+        assert value.numel() > 0, f"Parameter {key} should not be empty"
+
+    # Verify tensor shapes are reasonable (not zero-dimensional)
+    for key, value in params_dict.items():
+        assert (
+            len(value.shape) > 0
+        ), f"Parameter {key} should have at least 1 dimension, got shape {value.shape}"
+
+    # Check that we have parameters for all layers
+    for layer_idx in range(bridge.cfg.n_layers):
+        assert (
+            f"blocks.{layer_idx}.attn.W_Q" in params_dict
+        ), f"Should contain query weights for layer {layer_idx}"
+        assert (
+            f"blocks.{layer_idx}.attn.W_K" in params_dict
+        ), f"Should contain key weights for layer {layer_idx}"
+        assert (
+            f"blocks.{layer_idx}.attn.W_V" in params_dict
+        ), f"Should contain value weights for layer {layer_idx}"
+        assert (
+            f"blocks.{layer_idx}.attn.W_O" in params_dict
+        ), f"Should contain output weights for layer {layer_idx}"
+
+    # Explicit cleanup to help CI memory management
+    del params_dict
+    del bridge
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def test_get_params_parameter_shapes():
+    """Test that get_params returns parameters with expected shapes for GPT-2."""
+    model_name = "gpt2"
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    params_dict = bridge.get_params()
+
+    # Check embedding shapes
+    embed_weight = params_dict["embed.W_E"]
+    assert embed_weight.shape == (
+        bridge.cfg.d_vocab,
+        bridge.cfg.d_model,
+    ), f"Embedding weight shape should be ({bridge.cfg.d_vocab}, {bridge.cfg.d_model}), got {embed_weight.shape}"
+
+    pos_embed_weight = params_dict["pos_embed.W_pos"]
+    assert pos_embed_weight.shape == (
+        bridge.cfg.n_ctx,
+        bridge.cfg.d_model,
+    ), f"Position embedding weight shape should be ({bridge.cfg.n_ctx}, {bridge.cfg.d_model}), got {pos_embed_weight.shape}"
+
+    # Check attention weight shapes for first layer
+    w_q = params_dict["blocks.0.attn.W_Q"]
+    w_k = params_dict["blocks.0.attn.W_K"]
+    w_v = params_dict["blocks.0.attn.W_V"]
+    w_o = params_dict["blocks.0.attn.W_O"]
+
+    expected_qkv_shape = (bridge.cfg.n_heads, bridge.cfg.d_model, bridge.cfg.d_head)
+    expected_o_shape = (bridge.cfg.n_heads, bridge.cfg.d_head, bridge.cfg.d_model)
+
+    assert (
+        w_q.shape == expected_qkv_shape
+    ), f"W_Q shape should be {expected_qkv_shape}, got {w_q.shape}"
+    assert (
+        w_k.shape == expected_qkv_shape
+    ), f"W_K shape should be {expected_qkv_shape}, got {w_k.shape}"
+    assert (
+        w_v.shape == expected_qkv_shape
+    ), f"W_V shape should be {expected_qkv_shape}, got {w_v.shape}"
+    assert w_o.shape == expected_o_shape, f"W_O shape should be {expected_o_shape}, got {w_o.shape}"
+
+    # Check attention bias shapes
+    b_q = params_dict["blocks.0.attn.b_Q"]
+    b_k = params_dict["blocks.0.attn.b_K"]
+    b_v = params_dict["blocks.0.attn.b_V"]
+    b_o = params_dict["blocks.0.attn.b_O"]
+
+    expected_qkv_bias_shape = (bridge.cfg.n_heads, bridge.cfg.d_head)
+    expected_o_bias_shape = (bridge.cfg.d_model,)
+
+    assert (
+        b_q.shape == expected_qkv_bias_shape
+    ), f"b_Q shape should be {expected_qkv_bias_shape}, got {b_q.shape}"
+    assert (
+        b_k.shape == expected_qkv_bias_shape
+    ), f"b_K shape should be {expected_qkv_bias_shape}, got {b_k.shape}"
+    assert (
+        b_v.shape == expected_qkv_bias_shape
+    ), f"b_V shape should be {expected_qkv_bias_shape}, got {b_v.shape}"
+    assert (
+        b_o.shape == expected_o_bias_shape
+    ), f"b_O shape should be {expected_o_bias_shape}, got {b_o.shape}"
+
+
+def test_get_params_missing_components():
+    """Test that get_params gracefully handles missing components with zero tensors."""
+    model_name = "gpt2"
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    # Test that the method works normally first
+    params_dict = bridge.get_params()
+    assert isinstance(params_dict, dict)
+
+    # Test handling of missing components - should return zero tensors instead of exceptions
+    # Save original components
+    original_embed = bridge.embed
+    original_pos_embed = bridge.pos_embed
+    original_unembed = bridge.unembed
+
+    try:
+        # Test missing embed component - should return zero tensor
+        del bridge.embed
+        params_dict = bridge.get_params()
+        assert isinstance(params_dict, dict)
+        assert "embed.W_E" in params_dict
+        embed_weight = params_dict["embed.W_E"]
+        assert torch.all(embed_weight == 0), "Missing embed should be filled with zeros"
+        assert embed_weight.shape == (bridge.cfg.d_vocab, bridge.cfg.d_model)
+
+        # Restore embed, test missing pos_embed
+        bridge.embed = original_embed
+        del bridge.pos_embed
+        params_dict = bridge.get_params()
+        assert isinstance(params_dict, dict)
+        assert "pos_embed.W_pos" in params_dict
+        pos_embed_weight = params_dict["pos_embed.W_pos"]
+        assert torch.all(pos_embed_weight == 0), "Missing pos_embed should be filled with zeros"
+        assert pos_embed_weight.shape == (bridge.cfg.n_ctx, bridge.cfg.d_model)
+
+        # Restore pos_embed, test missing unembed
+        bridge.pos_embed = original_pos_embed
+        del bridge.unembed
+        params_dict = bridge.get_params()
+        assert isinstance(params_dict, dict)
+        assert "unembed.W_U" in params_dict
+        unembed_weight = params_dict["unembed.W_U"]
+        assert torch.all(unembed_weight == 0), "Missing unembed should be filled with zeros"
+        assert unembed_weight.shape == (bridge.cfg.d_model, bridge.cfg.d_vocab)
+
+    finally:
+        # Always restore components
+        bridge.embed = original_embed
+        bridge.pos_embed = original_pos_embed
+        bridge.unembed = original_unembed
+
+
+def test_get_params_consistency():
+    """Test that get_params returns consistent results across multiple calls."""
+    model_name = "gpt2"
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    # Get parameters twice
+    params1 = bridge.get_params()
+    params2 = bridge.get_params()
+
+    # Should have same keys
+    assert set(params1.keys()) == set(
+        params2.keys()
+    ), "Parameter keys should be consistent across calls"
+
+    # Should have same tensor shapes and values
+    for key in params1.keys():
+        assert params1[key].shape == params2[key].shape, f"Shape mismatch for {key}"
+        assert torch.equal(params1[key], params2[key]), f"Value mismatch for {key}"
+
+
+def test_get_params_configuration_mismatch():
+    """Test that get_params raises ValueError for configuration mismatches."""
+    model_name = "gpt2"
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    # Test that the method works normally first
+    params_dict = bridge.get_params()
+    assert isinstance(params_dict, dict)
+
+    # Save original configuration
+    original_n_layers = bridge.cfg.n_layers
+
+    try:
+        # Simulate configuration mismatch - more layers in config than actual blocks
+        bridge.cfg.n_layers = len(bridge.blocks) + 2
+
+        with pytest.raises(ValueError, match="Configuration mismatch.*blocks found"):
+            bridge.get_params()
+
+    finally:
+        # Always restore original configuration
+        bridge.cfg.n_layers = original_n_layers
+
+
+def test_get_params_multi_query_attention_reshaping():
+    """Test Multi-Query Attention weight reshaping logic without requiring a large model.
+
+    This test verifies that the get_params function can correctly handle different
+    weight shapes that occur in Multi-Query Attention architectures, where K and V
+    weights have different shapes than Q weights.
+    """
+    model_name = "gpt2"
+    bridge = TransformerBridge.boot_transformers(model_name)
+
+    # Get the original attention layer to modify
+    original_attn = bridge.blocks[0].attn
+    original_k_weight = original_attn.k.weight.clone()
+    original_v_weight = original_attn.v.weight.clone()
+
+    try:
+        # Test case 1: Simulate MQA where K and V have shape [d_head, d_model]
+        # instead of [d_model, d_model]
+        d_head = bridge.cfg.d_head
+        d_model = bridge.cfg.d_model
+
+        # Create MQA-style K and V weights with shape [d_head, d_model]
+        mqa_k_weight = torch.randn(
+            d_head, d_model, dtype=original_k_weight.dtype, device=original_k_weight.device
+        )
+        mqa_v_weight = torch.randn(
+            d_head, d_model, dtype=original_v_weight.dtype, device=original_v_weight.device
+        )
+
+        # Temporarily replace the weights
+        original_attn.k.weight.data = mqa_k_weight
+        original_attn.v.weight.data = mqa_v_weight
+
+        # This should work without raising exceptions
+        params_dict = bridge.get_params()
+
+        # Verify the weights were reshaped correctly
+        # For MQA: K and V should be expanded from [d_head, d_model] to
+        # [n_heads, d_model, d_head] (same as Q)
+        k_param = params_dict["blocks.0.attn.W_K"]
+        v_param = params_dict["blocks.0.attn.W_V"]
+
+        expected_shape = (bridge.cfg.n_heads, bridge.cfg.d_model, bridge.cfg.d_head)
+        assert (
+            k_param.shape == expected_shape
+        ), f"K weight should be reshaped to {expected_shape}, got {k_param.shape}"
+        assert (
+            v_param.shape == expected_shape
+        ), f"V weight should be reshaped to {expected_shape}, got {v_param.shape}"
+
+        # Verify that all heads contain the transposed MQA weight (due to transpose + expand operation)
+        expected_k_per_head = mqa_k_weight.transpose(0, 1)  # [d_head, d_model] -> [d_model, d_head]
+        expected_v_per_head = mqa_v_weight.transpose(0, 1)  # [d_head, d_model] -> [d_model, d_head]
+
+        for head_idx in range(bridge.cfg.n_heads):
+            assert torch.allclose(
+                k_param[head_idx], expected_k_per_head
+            ), f"K head {head_idx} should match transposed MQA weight"
+            assert torch.allclose(
+                v_param[head_idx], expected_v_per_head
+            ), f"V head {head_idx} should match transposed MQA weight"
+
+    finally:
+        # Always restore original weights
+        original_attn.k.weight.data = original_k_weight
+        original_attn.v.weight.data = original_v_weight
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
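
For reference, the transpose-and-expand behaviour that test_get_params_multi_query_attention_reshaping checks can be reproduced in isolation. A standalone sketch with GPT-2-sized dimensions as stand-ins; only the shape logic asserted by the test is assumed here, not any bridge internals:

import torch

n_heads, d_model, d_head = 12, 768, 64
kv_weight = torch.randn(d_head, d_model)  # single shared K/V head (MQA layout)

# [d_head, d_model] -> [d_model, d_head], then broadcast the shared head
per_head = kv_weight.transpose(0, 1)
expanded = per_head.unsqueeze(0).expand(n_heads, d_model, d_head)

assert expanded.shape == (n_heads, d_model, d_head)  # same layout as W_Q
assert torch.allclose(expanded[3], per_head)  # every head holds the shared weight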
