4 | 4 | including model initialization, text generation, hooks, and caching. |
5 | 5 | """ |
6 | 6 | |
| 7 | +import gc |
7 | 8 | import logging |
8 | 9 | |
9 | 10 | import pytest |
@@ -343,5 +344,331 @@ def capture_pattern_hook(tensor, hook): |
343 | 344 | bridge.blocks[0].attn.hook_pattern.remove_hooks() |
344 | 345 | |
345 | 346 | |
| 347 | +@pytest.mark.parametrize( |
| 348 | + "model_name", |
| 349 | + [ |
| 350 | + "gpt2", # GPT-2 architecture |
| 351 | + "distilgpt2", # DistilGPT-2 architecture (smaller GPT-2) |
| 352 | + "EleutherAI/pythia-70m", # Pythia architecture (smallest, ~70M params) |
| 353 | + "EleutherAI/gpt-neo-125M", # GPT-Neo architecture |
| 354 | + "google/gemma-2-2b-it", # Gemma architecture (Grouped Query Attention) |
| 355 | + ], |
| 356 | +) |
| 357 | +def test_get_params(model_name): |
| 358 | + """Test that get_params works correctly with different model architectures. |
| 359 | + |
| 360 | + This test verifies that the get_params function can successfully extract |
| 361 | + parameters from various model types (GPT-2, DistilGPT-2, Pythia, GPT-Neo, Gemma) |
| 362 | + without encountering attribute errors or missing component issues. This includes |
| 363 | + models with different attention architectures like Grouped Query Attention (GQA). |
| 364 | + The parametrized models range from roughly 70M to 2B parameters. |
| 365 | + |
| 366 | + Args: |
| 367 | + model_name: The model name to test (parameterized) |
| 368 | + """ |
| 369 | + # Clear any existing cache/memory before loading models |
| 370 | + gc.collect() |
| 371 | + if torch.cuda.is_available(): |
| 372 | + torch.cuda.empty_cache() |
| 373 | + |
| 374 | + bridge = TransformerBridge.boot_transformers(model_name) |
| 375 | + |
| 376 | + # This should not raise any exceptions |
| 377 | + try: |
| 378 | + params_dict = bridge.get_params() |
| 379 | + except Exception as e: |
| 380 | + pytest.fail(f"get_params failed for {model_name}: {e}") |
| 381 | + |
| 382 | + # Verify that we got a dictionary with expected keys |
| 383 | + assert isinstance(params_dict, dict), "get_params should return a dictionary" |
| 384 | + assert len(params_dict) > 0, "Parameters dictionary should not be empty" |
| 385 | + |
| 386 | + # Check for expected embedding parameters |
| 387 | + assert "embed.W_E" in params_dict, "Should contain embedding weights" |
| 388 | + assert "pos_embed.W_pos" in params_dict, "Should contain positional embedding weights" |
| 389 | + |
| 390 | + # Check for expected layer parameters (at least layer 0) |
| 391 | + assert "blocks.0.attn.W_Q" in params_dict, "Should contain query weights for layer 0" |
| 392 | + assert "blocks.0.attn.W_K" in params_dict, "Should contain key weights for layer 0" |
| 393 | + assert "blocks.0.attn.W_V" in params_dict, "Should contain value weights for layer 0" |
| 394 | + assert "blocks.0.attn.W_O" in params_dict, "Should contain output weights for layer 0" |
| 395 | + |
| 396 | + # Check for attention biases |
| 397 | + assert "blocks.0.attn.b_Q" in params_dict, "Should contain query biases for layer 0" |
| 398 | + assert "blocks.0.attn.b_K" in params_dict, "Should contain key biases for layer 0" |
| 399 | + assert "blocks.0.attn.b_V" in params_dict, "Should contain value biases for layer 0" |
| 400 | + assert "blocks.0.attn.b_O" in params_dict, "Should contain output biases for layer 0" |
| 401 | + |
| 402 | + # Check for MLP parameters |
| 403 | + assert "blocks.0.mlp.W_in" in params_dict, "Should contain MLP input weights for layer 0" |
| 404 | + assert "blocks.0.mlp.W_out" in params_dict, "Should contain MLP output weights for layer 0" |
| 405 | + assert "blocks.0.mlp.b_in" in params_dict, "Should contain MLP input biases for layer 0" |
| 406 | + assert "blocks.0.mlp.b_out" in params_dict, "Should contain MLP output biases for layer 0" |
| 407 | + |
| 408 | + # Check for unembedding weights |
| 409 | + assert "unembed.W_U" in params_dict, "Should contain unembedding weights" |
| 410 | + |
| 411 | + # Verify that all parameter values are tensors |
| 412 | + for key, value in params_dict.items(): |
| 413 | + assert isinstance( |
| 414 | + value, torch.Tensor |
| 415 | + ), f"Parameter {key} should be a tensor, got {type(value)}" |
| 416 | + assert value.numel() > 0, f"Parameter {key} should not be empty" |
| 417 | + |
| 418 | + # Verify tensor shapes are reasonable (not zero-dimensional) |
| 419 | + for key, value in params_dict.items(): |
| 420 | + assert ( |
| 421 | + len(value.shape) > 0 |
| 422 | + ), f"Parameter {key} should have at least 1 dimension, got shape {value.shape}" |
| 423 | + |
| 424 | + # Check that we have parameters for all layers |
| 425 | + for layer_idx in range(bridge.cfg.n_layers): |
| 426 | + assert ( |
| 427 | + f"blocks.{layer_idx}.attn.W_Q" in params_dict |
| 428 | + ), f"Should contain query weights for layer {layer_idx}" |
| 429 | + assert ( |
| 430 | + f"blocks.{layer_idx}.attn.W_K" in params_dict |
| 431 | + ), f"Should contain key weights for layer {layer_idx}" |
| 432 | + assert ( |
| 433 | + f"blocks.{layer_idx}.attn.W_V" in params_dict |
| 434 | + ), f"Should contain value weights for layer {layer_idx}" |
| 435 | + assert ( |
| 436 | + f"blocks.{layer_idx}.attn.W_O" in params_dict |
| 437 | + ), f"Should contain output weights for layer {layer_idx}" |
| 438 | + |
| 439 | + # Explicit cleanup to help CI memory management |
| 440 | + del params_dict |
| 441 | + del bridge |
| 442 | + gc.collect() |
| 443 | + if torch.cuda.is_available(): |
| 444 | + torch.cuda.empty_cache() |
| 445 | + |
| 446 | + |
| 447 | +def test_get_params_parameter_shapes(): |
| 448 | + """Test that get_params returns parameters with expected shapes for GPT-2.""" |
| 449 | + model_name = "gpt2" |
| 450 | + bridge = TransformerBridge.boot_transformers(model_name) |
| 451 | + |
| 452 | + params_dict = bridge.get_params() |
| 453 | + |
| 454 | + # Check embedding shapes |
| 455 | + embed_weight = params_dict["embed.W_E"] |
| 456 | + assert embed_weight.shape == ( |
| 457 | + bridge.cfg.d_vocab, |
| 458 | + bridge.cfg.d_model, |
| 459 | + ), f"Embedding weight shape should be ({bridge.cfg.d_vocab}, {bridge.cfg.d_model}), got {embed_weight.shape}" |
| 460 | + |
| 461 | + pos_embed_weight = params_dict["pos_embed.W_pos"] |
| 462 | + assert pos_embed_weight.shape == ( |
| 463 | + bridge.cfg.n_ctx, |
| 464 | + bridge.cfg.d_model, |
| 465 | + ), f"Position embedding weight shape should be ({bridge.cfg.n_ctx}, {bridge.cfg.d_model}), got {pos_embed_weight.shape}" |
| 466 | + |
| 467 | + # Check attention weight shapes for first layer |
| 468 | + w_q = params_dict["blocks.0.attn.W_Q"] |
| 469 | + w_k = params_dict["blocks.0.attn.W_K"] |
| 470 | + w_v = params_dict["blocks.0.attn.W_V"] |
| 471 | + w_o = params_dict["blocks.0.attn.W_O"] |
| 472 | + |
| 473 | + expected_qkv_shape = (bridge.cfg.n_heads, bridge.cfg.d_model, bridge.cfg.d_head) |
| 474 | + expected_o_shape = (bridge.cfg.n_heads, bridge.cfg.d_head, bridge.cfg.d_model) |
| 475 | + |
| 476 | + assert ( |
| 477 | + w_q.shape == expected_qkv_shape |
| 478 | + ), f"W_Q shape should be {expected_qkv_shape}, got {w_q.shape}" |
| 479 | + assert ( |
| 480 | + w_k.shape == expected_qkv_shape |
| 481 | + ), f"W_K shape should be {expected_qkv_shape}, got {w_k.shape}" |
| 482 | + assert ( |
| 483 | + w_v.shape == expected_qkv_shape |
| 484 | + ), f"W_V shape should be {expected_qkv_shape}, got {w_v.shape}" |
| 485 | + assert w_o.shape == expected_o_shape, f"W_O shape should be {expected_o_shape}, got {w_o.shape}" |
| 486 | + |
| 487 | + # Check attention bias shapes |
| 488 | + b_q = params_dict["blocks.0.attn.b_Q"] |
| 489 | + b_k = params_dict["blocks.0.attn.b_K"] |
| 490 | + b_v = params_dict["blocks.0.attn.b_V"] |
| 491 | + b_o = params_dict["blocks.0.attn.b_O"] |
| 492 | + |
| 493 | + expected_qkv_bias_shape = (bridge.cfg.n_heads, bridge.cfg.d_head) |
| 494 | + expected_o_bias_shape = (bridge.cfg.d_model,) |
| 495 | + |
| 496 | + assert ( |
| 497 | + b_q.shape == expected_qkv_bias_shape |
| 498 | + ), f"b_Q shape should be {expected_qkv_bias_shape}, got {b_q.shape}" |
| 499 | + assert ( |
| 500 | + b_k.shape == expected_qkv_bias_shape |
| 501 | + ), f"b_K shape should be {expected_qkv_bias_shape}, got {b_k.shape}" |
| 502 | + assert ( |
| 503 | + b_v.shape == expected_qkv_bias_shape |
| 504 | + ), f"b_V shape should be {expected_qkv_bias_shape}, got {b_v.shape}" |
| 505 | + assert ( |
| 506 | + b_o.shape == expected_o_bias_shape |
| 507 | + ), f"b_O shape should be {expected_o_bias_shape}, got {b_o.shape}" |
| 508 | + |
| 509 | + |
| 510 | +def test_get_params_missing_components(): |
| 511 | + """Test that get_params gracefully handles missing components with zero tensors.""" |
| 512 | + model_name = "gpt2" |
| 513 | + bridge = TransformerBridge.boot_transformers(model_name) |
| 514 | + |
| 515 | + # Test that the method works normally first |
| 516 | + params_dict = bridge.get_params() |
| 517 | + assert isinstance(params_dict, dict) |
| 518 | + |
| 519 | + # Test handling of missing components - should return zero tensors instead of exceptions |
| 520 | + # Save original components |
| 521 | + original_embed = bridge.embed |
| 522 | + original_pos_embed = bridge.pos_embed |
| 523 | + original_unembed = bridge.unembed |
| 524 | + |
| 525 | + try: |
| 526 | + # Test missing embed component - should return zero tensor |
| 527 | + del bridge.embed |
| 528 | + params_dict = bridge.get_params() |
| 529 | + assert isinstance(params_dict, dict) |
| 530 | + assert "embed.W_E" in params_dict |
| 531 | + embed_weight = params_dict["embed.W_E"] |
| 532 | + assert torch.all(embed_weight == 0), "Missing embed should be filled with zeros" |
| 533 | + assert embed_weight.shape == (bridge.cfg.d_vocab, bridge.cfg.d_model) |
| 534 | + |
| 535 | + # Restore embed, test missing pos_embed |
| 536 | + bridge.embed = original_embed |
| 537 | + del bridge.pos_embed |
| 538 | + params_dict = bridge.get_params() |
| 539 | + assert isinstance(params_dict, dict) |
| 540 | + assert "pos_embed.W_pos" in params_dict |
| 541 | + pos_embed_weight = params_dict["pos_embed.W_pos"] |
| 542 | + assert torch.all(pos_embed_weight == 0), "Missing pos_embed should be filled with zeros" |
| 543 | + assert pos_embed_weight.shape == (bridge.cfg.n_ctx, bridge.cfg.d_model) |
| 544 | + |
| 545 | + # Restore pos_embed, test missing unembed |
| 546 | + bridge.pos_embed = original_pos_embed |
| 547 | + del bridge.unembed |
| 548 | + params_dict = bridge.get_params() |
| 549 | + assert isinstance(params_dict, dict) |
| 550 | + assert "unembed.W_U" in params_dict |
| 551 | + unembed_weight = params_dict["unembed.W_U"] |
| 552 | + assert torch.all(unembed_weight == 0), "Missing unembed should be filled with zeros" |
| 553 | + assert unembed_weight.shape == (bridge.cfg.d_model, bridge.cfg.d_vocab) |
| 554 | + |
| 555 | + finally: |
| 556 | + # Always restore components |
| 557 | + bridge.embed = original_embed |
| 558 | + bridge.pos_embed = original_pos_embed |
| 559 | + bridge.unembed = original_unembed |
| 560 | + |
| 561 | + |
| 562 | +def test_get_params_consistency(): |
| 563 | + """Test that get_params returns consistent results across multiple calls.""" |
| 564 | + model_name = "gpt2" |
| 565 | + bridge = TransformerBridge.boot_transformers(model_name) |
| 566 | + |
| 567 | + # Get parameters twice |
| 568 | + params1 = bridge.get_params() |
| 569 | + params2 = bridge.get_params() |
| 570 | + |
| 571 | + # Should have same keys |
| 572 | + assert set(params1.keys()) == set( |
| 573 | + params2.keys() |
| 574 | + ), "Parameter keys should be consistent across calls" |
| 575 | + |
| 576 | + # Should have same tensor shapes and values |
| 577 | + for key in params1.keys(): |
| 578 | + assert params1[key].shape == params2[key].shape, f"Shape mismatch for {key}" |
| 579 | + assert torch.equal(params1[key], params2[key]), f"Value mismatch for {key}" |
| 580 | + |
| 581 | + |
| 582 | +def test_get_params_configuration_mismatch(): |
| 583 | + """Test that get_params raises ValueError for configuration mismatches.""" |
| 584 | + model_name = "gpt2" |
| 585 | + bridge = TransformerBridge.boot_transformers(model_name) |
| 586 | + |
| 587 | + # Test that the method works normally first |
| 588 | + params_dict = bridge.get_params() |
| 589 | + assert isinstance(params_dict, dict) |
| 590 | + |
| 591 | + # Save original configuration |
| 592 | + original_n_layers = bridge.cfg.n_layers |
| 593 | + |
| 594 | + try: |
| 595 | + # Simulate configuration mismatch - more layers in config than actual blocks |
| 596 | + bridge.cfg.n_layers = len(bridge.blocks) + 2 |
| 597 | + |
| 598 | + with pytest.raises(ValueError, match="Configuration mismatch.*blocks found"): |
| 599 | + bridge.get_params() |
| 600 | + |
| 601 | + finally: |
| 602 | + # Always restore original configuration |
| 603 | + bridge.cfg.n_layers = original_n_layers |
| 604 | + |
| 605 | + |
| 606 | +def test_get_params_multi_query_attention_reshaping(): |
| 607 | + """Test Multi-Query Attention weight reshaping logic without requiring a large model. |
| 608 | + |
| 609 | + This test verifies that the get_params function can correctly handle different |
| 610 | + weight shapes that occur in Multi-Query Attention architectures, where K and V |
| 611 | + weights have different shapes than Q weights. |
| 612 | + """ |
| 613 | + model_name = "gpt2" |
| 614 | + bridge = TransformerBridge.boot_transformers(model_name) |
| 615 | + |
| 616 | + # Get the original attention layer to modify |
| 617 | + original_attn = bridge.blocks[0].attn |
| 618 | + original_k_weight = original_attn.k.weight.clone() |
| 619 | + original_v_weight = original_attn.v.weight.clone() |
| 620 | + |
| 621 | + try: |
| 622 | + # Simulate MQA where K and V have shape [d_head, d_model] |
| 623 | + # instead of [d_model, d_model] |
| 624 | + d_head = bridge.cfg.d_head |
| 625 | + d_model = bridge.cfg.d_model |
| 626 | + |
| 627 | + # Create MQA-style K and V weights with shape [d_head, d_model] |
| 628 | + mqa_k_weight = torch.randn( |
| 629 | + d_head, d_model, dtype=original_k_weight.dtype, device=original_k_weight.device |
| 630 | + ) |
| 631 | + mqa_v_weight = torch.randn( |
| 632 | + d_head, d_model, dtype=original_v_weight.dtype, device=original_v_weight.device |
| 633 | + ) |
| 634 | + |
| 635 | + # Temporarily replace the weights |
| 636 | + original_attn.k.weight.data = mqa_k_weight |
| 637 | + original_attn.v.weight.data = mqa_v_weight |
| 638 | + |
| 639 | + # This should work without raising exceptions |
| 640 | + params_dict = bridge.get_params() |
| 641 | + |
| 642 | + # Verify the weights were reshaped correctly |
| 643 | + # For MQA: K and V should be expanded from [d_head, d_model] to [n_heads, d_model, d_head] (same as Q) |
| 644 | + k_param = params_dict["blocks.0.attn.W_K"] |
| 645 | + v_param = params_dict["blocks.0.attn.W_V"] |
| 646 | + |
| 647 | + expected_shape = (bridge.cfg.n_heads, bridge.cfg.d_model, bridge.cfg.d_head) |
| 648 | + assert ( |
| 649 | + k_param.shape == expected_shape |
| 650 | + ), f"K weight should be reshaped to {expected_shape}, got {k_param.shape}" |
| 651 | + assert ( |
| 652 | + v_param.shape == expected_shape |
| 653 | + ), f"V weight should be reshaped to {expected_shape}, got {v_param.shape}" |
| 654 | + |
| 655 | + # Verify that all heads contain the transposed MQA weight (due to transpose + expand operation) |
| 656 | + expected_k_per_head = mqa_k_weight.transpose(0, 1) # [d_head, d_model] -> [d_model, d_head] |
| 657 | + expected_v_per_head = mqa_v_weight.transpose(0, 1) # [d_head, d_model] -> [d_model, d_head] |
| 658 | + |
| 659 | + for head_idx in range(bridge.cfg.n_heads): |
| 660 | + assert torch.allclose( |
| 661 | + k_param[head_idx], expected_k_per_head |
| 662 | + ), f"K head {head_idx} should match transposed MQA weight" |
| 663 | + assert torch.allclose( |
| 664 | + v_param[head_idx], expected_v_per_head |
| 665 | + ), f"V head {head_idx} should match transposed MQA weight" |
| 666 | + |
| 667 | + finally: |
| 668 | + # Always restore original weights |
| 669 | + original_attn.k.weight.data = original_k_weight |
| 670 | + original_attn.v.weight.data = original_v_weight |
| 671 | + |
| 672 | + |
346 | 673 | if __name__ == "__main__": |
347 | 674 | pytest.main([__file__]) |
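
For reference, here is a minimal standalone sketch of the API these tests exercise. It assumes TransformerBridge is importable the same way this test module imports it; the import path shown is a guess and may differ in the actual package.

import torch

from transformer_lens.model_bridge import TransformerBridge  # assumed import path

# Load a small model through the bridge and extract TransformerLens-style parameters.
bridge = TransformerBridge.boot_transformers("gpt2")
params = bridge.get_params()

# Keys follow the "blocks.{layer}.attn.W_Q" naming convention asserted in the tests above.
w_q = params["blocks.0.attn.W_Q"]
assert w_q.shape == (bridge.cfg.n_heads, bridge.cfg.d_model, bridge.cfg.d_head)

# Every value is a non-empty torch.Tensor.
assert all(isinstance(v, torch.Tensor) and v.numel() > 0 for v in params.values())

Note that the google/gemma-2-2b-it case of the parametrized test downloads a roughly 2B-parameter checkpoint and may require Hugging Face authentication, so it is considerably heavier than the other entries.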