
Conversation

@huseyincavusbi

Description

Add support for Gemma 3 (270M, 1B, 4B, 12B, 27B) and MedGemma (4B, 27B) model families for mechanistic interpretability research.

Motivation: Enable TransformerLens analysis on Google's latest Gemma 3 models, including the medically fine-tuned MedGemma variants for healthcare AI research. This work supports an ongoing research collaboration with Great Ormond Street Hospital's DRIVE unit on mechanistic interpretability of medical LLMs.

14 new models added:

  • Gemma 3: 270M, 1B, 4B, 12B, 27B (pt and it variants)
  • MedGemma: 4B, 27B variants

New architectural features implemented:

  • Hybrid local/global attention (5:1 sliding window pattern)
  • Per-layer RoPE bases (10k local, 1M global) via new rotary_base_local parameter
  • Q/K normalization via new use_qk_norm parameter
  • Multimodal text-only weight extraction for vision-language models
  • Memory-safe 8K default context (overridable to 131K)
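
As a quick illustration of the memory-safe default and the 131K override above, loading one of the new models could look roughly like this (the model name and the n_ctx keyword follow this PR's description; treat it as a sketch rather than settled API):

from transformer_lens import HookedTransformer

# Loads with the memory-safe 8K default context described above.
model = HookedTransformer.from_pretrained("google/gemma-3-270m-it")

# Override the context window up to the full 131K when memory allows
# (n_ctx override added by this PR).
model_long = HookedTransformer.from_pretrained("google/gemma-3-270m-it", n_ctx=131072)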

Dependencies: No new dependencies required for Gemma 3 support. This PR also includes minor CI fixes that pin existing dependency versions (huggingface-hub, protobuf, transformers) to resolve notebook CI failures.

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

bryce13950 and others added 30 commits June 12, 2025 11:19
- Add google/gemma-3-270m and google/gemma-3-270m-it to supported models
- Add architecture detection for Gemma3ForCausalLM
- Add hardcoded configuration with d_head=256 and use_qk_norm=True
- Add Q/K normalization weight loading in gemma weight converter
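
For illustration, the Q/K normalization weight loading above amounts to copying per-layer query/key norm scales into the converted state dict; both the Hugging Face attribute path and the TransformerLens key names shown here are assumptions, not the converter's actual code:

# Illustrative only; attribute paths and state-dict key names are assumed.
def copy_qk_norm_weights(hf_model, cfg, state_dict):
    for l in range(cfg.n_layers):
        attn = hf_model.model.layers[l].self_attn
        state_dict[f"blocks.{l}.attn.q_norm.w"] = attn.q_norm.weight
        state_dict[f"blocks.{l}.attn.k_norm.w"] = attn.k_norm.weight
    return state_dict
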
- Add google/gemma-3-1b-pt and google/gemma-3-1b-it to supported models
- Add configuration with d_model=1152, d_mlp=6912, n_layers=26
- Maintains d_head=256 (hardcoded for all Gemma models)
- Includes use_qk_norm=True and use_normalization_before_and_after=True
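
Put together, the 1B hyperparameters quoted above look roughly like this as a config dict (field names follow HookedTransformerConfig; values not stated in the commit message are omitted):

# Sketch of the gemma-3-1b config values listed above; illustrative only.
gemma3_1b_cfg = {
    "d_model": 1152,
    "d_mlp": 6912,
    "n_layers": 26,
    "d_head": 256,                               # hardcoded for all Gemma models
    "use_qk_norm": True,
    "use_normalization_before_and_after": True,  # normalization before and after each block
    "n_ctx": 8192,                               # memory-safe default
}
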
…xtraction

- Add google/gemma-3-4b-pt, gemma-3-4b-it, medgemma-4b-pt, medgemma-4b-it
- Implement pattern-based architecture detection (CausalLM vs ConditionalGeneration)
- Add 4B config with GQA support (n_key_value_heads=4)
- Extract text-only weights from multimodal models via language_model component
- Add AutoModel loader for Gemma3ForConditionalGeneration architecture
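
A minimal sketch of the text-only extraction described above; the attribute name follows the language_model component mentioned in this commit, but the exact access path is an assumption:

# Illustrative: vision-language checkpoints (Gemma3ForConditionalGeneration)
# nest the text decoder under .language_model, while text-only checkpoints
# (Gemma3ForCausalLM) expose it directly.
def get_text_decoder(hf_model):
    return getattr(hf_model, "language_model", hf_model)
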
Add device parameter to all torch.zeros() calls in gemma weight conversion
to ensure bias tensors are created on the same device as weight tensors.
This fixes a RuntimeError when loading Gemma models on Apple Silicon with the MPS backend.

- Add device parameter to attention biases (b_Q, b_K, b_V, b_O)
- Add device parameter to MLP biases (b_in, b_out)
- Add device parameter to unembed bias (b_U)
- Handle both lm_head and tied embeddings for unembed device
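
In sketch form, the fix creates the zero-valued bias tensors on the caller-supplied device instead of the CPU default (tensor shapes follow TransformerLens conventions; the helper name is illustrative):

import torch

def make_gemma_biases(cfg, device):
    # Without device=, torch.zeros() lands on CPU, which raises a RuntimeError
    # on Apple Silicon when the converted weights already live on MPS.
    return {
        "b_Q": torch.zeros(cfg.n_heads, cfg.d_head, device=device),
        "b_O": torch.zeros(cfg.d_model, device=device),
        "b_in": torch.zeros(cfg.d_mlp, device=device),
    }
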
- Reduce default context: 270M/1B (32K->8K), 4B (131K->8K)
- Add n_ctx parameter for context length override
- Fix multimodal weight extraction (nested model access)
- Add kwargs filtering for n_ctx parameter
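
The kwargs filtering above keeps the TransformerLens-level override from leaking into the Hugging Face loader; schematically (helper name is illustrative):

def split_n_ctx(kwargs):
    # n_ctx is consumed by TransformerLens itself and must not be forwarded to
    # the Hugging Face from_pretrained call; everything else passes through.
    kwargs = dict(kwargs)
    return kwargs.pop("n_ctx", None), kwargs
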
- Added 6 new models: gemma-3-12b-pt/it, gemma-3-27b-pt/it, medgemma-27b-it/text-it
- 12B config: 3840 d_model, 48 layers, 16 heads, 8 KV heads (2:1 GQA)
- 27B config: 5376 d_model, 62 layers, 32 heads, 16 KV heads (2:1 GQA)
- All use safe 8K default context (overridable to 131K)
- Special handling for medgemma-27b-text-it (text-only, 262144 vocab)
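
To make the grouped-query attention ratios above concrete (pure arithmetic, no TransformerLens code):

# Each key/value head is shared by two query heads in both new configs.
for name, n_heads, n_kv_heads in [("gemma-3-12b", 16, 8), ("gemma-3-27b", 32, 16)]:
    print(f"{name}: {n_heads} query heads / {n_kv_heads} KV heads = "
          f"{n_heads // n_kv_heads} query heads per KV head")
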
Copilot AI review requested due to automatic review settings December 9, 2025 15:03

Copilot AI left a comment

Pull request overview

This PR adds comprehensive support for Google's Gemma 3 and MedGemma model families, enabling mechanistic interpretability research on 14 new model variants ranging from 270M to 27B parameters. The implementation includes novel architectural features like hybrid local/global attention patterns and per-layer RoPE bases, along with multimodal text-only weight extraction for vision-language models.

Key changes:

  • Added 14 new Gemma 3 and MedGemma model configurations with hybrid attention support
  • Implemented multimodal model weight extraction to handle both text-only and vision-language architectures
  • Introduced new configuration parameters (use_qk_norm, rotary_base_local) for advanced attention mechanisms
  • Pinned dependency versions and updated CI workflows to resolve notebook testing failures
  • Updated deprecated IPython magic method usage in notebooks
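
How the two new parameters interact with the hybrid attention pattern can be pictured with a short sketch; the 5:1 indexing convention and the helper names below are assumptions for illustration, not the actual TransformerLens API:

def is_local_attention_layer(layer_idx, pattern=6):
    # Five sliding-window ("local") layers followed by one global layer,
    # matching the 5:1 hybrid attention pattern described in this PR.
    return (layer_idx + 1) % pattern != 0

def rotary_base_for_layer(layer_idx, rotary_base=1_000_000, rotary_base_local=10_000):
    # Local layers use the 10k RoPE base; global layers use the 1M base.
    return rotary_base_local if is_local_attention_layer(layer_idx) else rotary_base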

Reviewed changes

Copilot reviewed 32 out of 34 changed files in this pull request and generated 3 comments.

Show a summary per file

  • transformer_lens/loading_from_pretrained.py: Added 14 new model entries, architecture detection logic for Gemma 3 variants, and extensive model configurations with hybrid attention patterns
  • transformer_lens/pretrained/weight_conversions/gemma.py: Implemented multimodal model detection and weight extraction from both text-only and vision-language architectures
  • transformer_lens/HookedTransformerConfig.py: Added use_qk_norm and rotary_base_local configuration parameters with documentation
  • transformer_lens/components/abstract_attention.py: Implemented per-layer RoPE base selection for hybrid local/global attention
  • tests/unit/test_gemma3_config.py: Comprehensive test suite for Gemma 3 configuration generation and architectural features
  • tests/unit/pretrained_weight_conversions/test_gemma.py: Unit tests for multimodal and text-only weight conversion with shape validation
  • pyproject.toml: Pinned dependency versions for huggingface-hub, protobuf, transformers, and added gradio to dev dependencies
  • .github/workflows/checks.yml: Added disk space cleanup, virtualenv configuration, and dependency version verification steps
  • demos/Interactive_Neuroscope.ipynb: Updated deprecated IPython magic methods and modified package installation conditions
  • demos/Colab_Compatibility.ipynb: Updated deprecated IPython magic methods and registered Poetry kernel
  • transformer_lens/utils.py: Added trailing commas for code style consistency
  • transformer_lens/HookedTransformer.py: Added documentation for n_ctx parameter override functionality
  • Multiple component files: Added blank lines after docstrings for PEP 8 compliance

pyproject.toml
  torch=[{version="<2.6", python=">=3.8,<3.9"}, {version=">=2.6", python=">=3.9"}]
  tqdm=">=4.64.1"
- transformers=[{version="<4.51", python=">=3.8,<3.9"}, {version=">=4.51", python=">=3.9"}]
+ transformers=[{version="<4.46.0", python=">=3.8,<3.9"}, {version="4.46.3", python=">=3.9"}]

Copilot AI Dec 9, 2025

The transformers version is pinned to exactly 4.46.3 for Python 3.9+. This rigid pinning could cause compatibility issues with other packages that may require different versions. Consider using a version range like >=4.46.3,<4.47 or ~=4.46.3 instead to allow patch version updates while maintaining compatibility.

Suggested change
- transformers=[{version="<4.46.0", python=">=3.8,<3.9"}, {version="4.46.3", python=">=3.9"}]
+ transformers=[{version="<4.46.0", python=">=3.8,<3.9"}, {version=">=4.46.3,<4.47", python=">=3.9"}]

