Add Gemma 3 and MedGemma model support #1149
base: dev
Conversation
Release v2.16.2
- Add google/gemma-3-270m and google/gemma-3-270m-it to supported models
- Add architecture detection for Gemma3ForCausalLM
- Add hardcoded configuration with d_head=256 and use_qk_norm=True
- Add Q/K normalization weight loading in gemma weight converter (sketched below)
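A hedged sketch of what that Q/K-norm weight loading could look like; the state-dict keys and HF attribute names are assumptions based on TransformerLens and Gemma conventions, not the PR's exact code:

```python
def convert_qk_norm_weights(state_dict: dict, hf_layer, layer_idx: int) -> None:
    # Copy per-head Q/K RMSNorm scales from the HF attention module into the
    # TransformerLens state dict (key and attribute names assumed).
    state_dict[f"blocks.{layer_idx}.attn.q_norm.w"] = hf_layer.self_attn.q_norm.weight
    state_dict[f"blocks.{layer_idx}.attn.k_norm.w"] = hf_layer.self_attn.k_norm.weight
```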
- Add google/gemma-3-1b-pt and google/gemma-3-1b-it to supported models
- Add configuration with d_model=1152, d_mlp=6912, n_layers=26
- Maintains d_head=256 (hardcoded for all Gemma models)
- Includes use_qk_norm=True and use_normalization_before_and_after=True
…xtraction
- Add google/gemma-3-4b-pt, gemma-3-4b-it, medgemma-4b-pt, medgemma-4b-it
- Implement pattern-based architecture detection (CausalLM vs ConditionalGeneration)
- Add 4B config with GQA support (n_key_value_heads=4)
- Extract text-only weights from multimodal models via language_model component (see the sketch below)
- Add AutoModel loader for Gemma3ForConditionalGeneration architecture
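The detection-plus-extraction pattern this commit describes might reduce to something like the following sketch; the function name and the fallback logic are illustrative assumptions, not the PR's exact code.

```python
from transformers import AutoModel, AutoModelForCausalLM

def load_gemma3_text_model(model_name: str, hf_config):
    """Sketch: pick a loader from the architecture string and unwrap the
    text decoder from multimodal checkpoints."""
    architecture = hf_config.architectures[0]
    if "ConditionalGeneration" in architecture:
        # Vision-language checkpoint: load via AutoModel, then take the
        # nested text decoder for text-only weight conversion.
        hf_model = AutoModel.from_pretrained(model_name)
        if hasattr(hf_model, "language_model"):
            return hf_model.language_model
        return hf_model.model.language_model  # deeper nesting on some versions
    # Plain Gemma3ForCausalLM checkpoint: load directly.
    return AutoModelForCausalLM.from_pretrained(model_name)
```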
Add device parameter to all torch.zeros() calls in gemma weight conversion to ensure bias tensors are created on the same device as weight tensors. This fixes a RuntimeError when loading Gemma models on Apple Silicon with the MPS backend (sketched below).
- Add device parameter to attention biases (b_Q, b_K, b_V, b_O)
- Add device parameter to MLP biases (b_in, b_out)
- Add device parameter to unembed bias (b_U)
- Handle both lm_head and tied embeddings for unembed device
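A minimal sketch of the fix, assuming the converter builds each zero bias next to the weight it accompanies (the helper name is hypothetical; tensor names follow TransformerLens conventions):

```python
import torch

def zeros_on_weight_device(weight: torch.Tensor, *shape: int) -> torch.Tensor:
    # Create a zero bias on the same device (and dtype) as its weight, so
    # MPS- or CUDA-resident weights are never paired with CPU-created biases.
    return torch.zeros(*shape, dtype=weight.dtype, device=weight.device)

# Before: b_Q = torch.zeros(n_heads, d_head)                  # always lands on CPU
# After:  b_Q = zeros_on_weight_device(W_Q, n_heads, d_head)  # follows W_Q's device
```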
- Reduce default context: 270M/1B (32K->8K), 4B (131K->8K)
- Add n_ctx parameter for context length override (usage sketched below)
- Fix multimodal weight extraction (nested model access)
- Add kwargs filtering for n_ctx parameter
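Assuming the override is threaded through from_pretrained (the commit adds kwargs filtering for exactly this), usage might look like:

```python
from transformer_lens import HookedTransformer

# Defaults are capped at 8K context to keep memory manageable; pass n_ctx
# to restore the model's full window (131072 for the 4B variant).
model = HookedTransformer.from_pretrained("google/gemma-3-4b-it", n_ctx=131072)
```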
- Added 6 new models: gemma-3-12b-pt/it, gemma-3-27b-pt/it, medgemma-27b-it/text-it
- 12B config: 3840 d_model, 48 layers, 16 heads, 8 KV heads (2:1 GQA; laid out as a config dict below)
- 27B config: 5376 d_model, 62 layers, 32 heads, 16 KV heads (2:1 GQA)
- All use safe 8K default context (overridable to 131K)
- Special handling for medgemma-27b-text-it (text-only, 262144 vocab)
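Laying the 12B numbers out as a config dict gives roughly the following; the keys mirror HookedTransformerConfig field names, and d_head/use_qk_norm are carried over from the earlier commits as assumptions.

```python
# Sketch of the gemma-3-12b configuration implied by the commit message.
gemma3_12b_cfg = {
    "d_model": 3840,
    "n_layers": 48,
    "n_heads": 16,
    "n_key_value_heads": 8,  # 2:1 grouped-query attention
    "d_head": 256,           # hardcoded for all Gemma models (earlier commit)
    "n_ctx": 8192,           # safe default; overridable up to 131072
    "use_qk_norm": True,     # carried over from the smaller Gemma 3 configs
}
```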
…ll huggingface-hub
Pull request overview
This PR adds comprehensive support for Google's Gemma 3 and MedGemma model families, enabling mechanistic interpretability research on 14 new model variants ranging from 270M to 27B parameters. The implementation includes novel architectural features like hybrid local/global attention patterns and per-layer RoPE bases, along with multimodal text-only weight extraction for vision-language models.
Key changes:
- Added 14 new Gemma 3 and MedGemma model configurations with hybrid attention support
- Implemented multimodal model weight extraction to handle both text-only and vision-language architectures
- Introduced new configuration parameters (`use_qk_norm`, `rotary_base_local`) for advanced attention mechanisms
- Pinned dependency versions and updated CI workflows to resolve notebook testing failures
- Updated deprecated IPython magic method usage in notebooks
Reviewed changes
Copilot reviewed 32 out of 34 changed files in this pull request and generated 3 comments.
Summary per file:
| File | Description |
|---|---|
| transformer_lens/loading_from_pretrained.py | Added 14 new model entries, architecture detection logic for Gemma 3 variants, and extensive model configurations with hybrid attention patterns |
| transformer_lens/pretrained/weight_conversions/gemma.py | Implemented multimodal model detection and weight extraction from both text-only and vision-language architectures |
| transformer_lens/HookedTransformerConfig.py | Added use_qk_norm and rotary_base_local configuration parameters with documentation |
| transformer_lens/components/abstract_attention.py | Implemented per-layer RoPE base selection for hybrid local/global attention (see the sketch after this table) |
| tests/unit/test_gemma3_config.py | Comprehensive test suite for Gemma 3 configuration generation and architectural features |
| tests/unit/pretrained_weight_conversions/test_gemma.py | Unit tests for multimodal and text-only weight conversion with shape validation |
| pyproject.toml | Pinned dependency versions for huggingface-hub, protobuf, transformers, and added gradio to dev dependencies |
| .github/workflows/checks.yml | Added disk space cleanup, virtualenv configuration, and dependency version verification steps |
| demos/Interactive_Neuroscope.ipynb | Updated deprecated IPython magic methods and modified package installation conditions |
| demos/Colab_Compatibility.ipynb | Updated deprecated IPython magic methods and registered Poetry kernel |
| transformer_lens/utils.py | Added trailing commas for code style consistency |
| transformer_lens/HookedTransformer.py | Added documentation for n_ctx parameter override functionality |
| Multiple component files | Added blank lines after docstrings for PEP 8 compliance |
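The per-layer RoPE base selection noted for abstract_attention.py might reduce to something like this sketch; the local/global interleave pattern is an assumption about Gemma 3's layer layout, not code from this PR.

```python
def rotary_base_for_layer(cfg, layer_idx: int) -> float:
    # Hybrid attention: sliding-window (local) layers read the new
    # cfg.rotary_base_local, while full (global) attention layers keep the
    # standard cfg.rotary_base. The every-sixth-layer-global rule below is
    # illustrative only.
    is_global = (layer_idx + 1) % 6 == 0
    return cfg.rotary_base if is_global else cfg.rotary_base_local
```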
```diff
 torch=[{version="<2.6", python=">=3.8,<3.9"}, {version=">=2.6", python=">=3.9"}]
 tqdm=">=4.64.1"
-transformers=[{version="<4.51", python=">=3.8,<3.9"}, {version=">=4.51", python=">=3.9"}]
+transformers=[{version="<4.46.0", python=">=3.8,<3.9"}, {version="4.46.3", python=">=3.9"}]
```
Copilot AI · Dec 9, 2025
The `transformers` version is pinned to exactly 4.46.3 for Python 3.9+. This rigid pin could cause compatibility issues with other packages that require different versions. Consider a version range like `>=4.46.3,<4.47` or `~=4.46.3` instead, allowing patch updates while maintaining compatibility.
| transformers=[{version="<4.46.0", python=">=3.8,<3.9"}, {version="4.46.3", python=">=3.9"}] | |
| transformers=[{version="<4.46.0", python=">=3.8,<3.9"}, {version=">=4.46.3,<4.47", python=">=3.9"}] |
Co-authored-by: Copilot <[email protected]>
This reverts commit 95cc561.
Description
Add support for Gemma 3 (270M, 1B, 4B, 12B, 27B) and MedGemma (4B, 27B) model families for mechanistic interpretability research.
Motivation: Enable TransformerLens analysis on Google's latest Gemma 3 models, including the medically fine-tuned MedGemma variants for healthcare AI research. This work supports an ongoing research collaboration with Great Ormond Street Hospital's DRIVE unit on mechanistic interpretability of medical LLMs.
14 new models added:
- google/gemma-3-270m, google/gemma-3-270m-it
- google/gemma-3-1b-pt, google/gemma-3-1b-it
- google/gemma-3-4b-pt, google/gemma-3-4b-it
- google/gemma-3-12b-pt, google/gemma-3-12b-it
- google/gemma-3-27b-pt, google/gemma-3-27b-it
- google/medgemma-4b-pt, google/medgemma-4b-it
- google/medgemma-27b-it, google/medgemma-27b-text-it
New architectural features implemented:
- `rotary_base_local` parameter: per-layer RoPE base for hybrid local/global attention (see the config sketch below)
- `use_qk_norm` parameter: Q/K normalization in attention
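To see how the two new parameters compose with the existing config surface, here is a hedged sketch of constructing a small Gemma-3-style HookedTransformerConfig; d_model/d_mlp/n_layers/d_head are the 1B values from the commits, while n_heads, d_vocab, act_fn, and both rotary base values are assumptions.

```python
from transformer_lens import HookedTransformerConfig

cfg = HookedTransformerConfig(
    d_model=1152,              # from the 1B commit
    d_mlp=6912,                # from the 1B commit
    n_layers=26,               # from the 1B commit
    n_heads=4,                 # assumed
    d_head=256,                # hardcoded for all Gemma models
    n_ctx=8192,                # safe default context
    d_vocab=262144,            # assumed (stated only for medgemma-27b-text-it)
    act_fn="gelu",             # assumed
    positional_embedding_type="rotary",
    rotary_base=1_000_000,     # assumed value for global-attention layers
    rotary_base_local=10_000,  # new in this PR; assumed value for local layers
    use_qk_norm=True,          # new in this PR
    use_normalization_before_and_after=True,
)
```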
Dependencies: No new dependencies required for Gemma 3 support. This PR also includes minor CI fixes that pin existing dependency versions (`huggingface-hub`, `protobuf`, `transformers`) to resolve notebook CI failures.

Type of change
Checklist: