Add MiMo dense MTP models bridge support #2387
Conversation
Pull request overview
This PR adds support for MiMo dense models from Xiaomi, which are built on top of the Qwen2 architecture with additional MTP (Multi-Token Prediction) layers. The implementation includes a bridge for weight conversion, provider configurations for various MiMo 7B model variants, and comprehensive test coverage.
Changes:
- Added `MimoBridge` class extending `Qwen2Bridge` with MTP-specific weight mappings and transformations
- Added `MiMoModelProvider7B` base class and 5 variant providers for different MiMo 7B models
- Added comprehensive unit tests for both bridge functionality and provider configurations
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| src/megatron/bridge/models/mimo/mimo_bridge.py | Implements bridge for MiMo models with MTP layer mappings and input projection weight transformations |
| src/megatron/bridge/models/mimo/mimo_causal_provider.py | Defines provider configurations for MiMo 7B model family with MTP parameters |
| src/megatron/bridge/models/mimo/\_\_init\_\_.py | Exports new MiMo provider classes |
| src/megatron/bridge/models/\_\_init\_\_.py | Exports new MiMo providers and bridge at top level |
| tests/unit_tests/models/mimo/test_mimo_bridge.py | Tests bridge registration, provider mapping, MTP layer mappings, and weight transformations |
| tests/unit_tests/models/mimo/test_mimo_causal_provider.py | Tests provider default configurations and variant-specific settings |
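The provider tests listed above check default configurations; a minimal sketch of that style of test is below. The `MiMoModelProvider7B` class here is a toy stand-in with assumed field names and defaults (`num_layers`, `mtp_num_layers`), not the actual provider from mimo_causal_provider.py.

```python
from dataclasses import dataclass

# Toy stand-in for MiMoModelProvider7B. The field names and default values
# are assumptions for illustration, not the real provider's defaults.
@dataclass
class MiMoModelProvider7B:
    num_layers: int = 36
    mtp_num_layers: int = 1  # number of Multi-Token Prediction layers

def test_mtp_enabled_by_default():
    provider = MiMoModelProvider7B()
    assert provider.mtp_num_layers == 1

def test_variant_can_override_mtp():
    provider = MiMoModelProvider7B(mtp_num_layers=0)
    assert provider.mtp_num_layers == 0

test_mtp_enabled_by_default()
test_variant_can_override_mtp()
```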
📝 Walkthrough

This PR adds MiMo (Xiaomi) model family support to Megatron Bridge by introducing MiMo 7B provider configurations (base, SFT, RL variants) and a custom `MimoBridge` class extending `Qwen2Bridge` with MiMo-specific weight mappings and parameter transformations.
Sequence diagram

```mermaid
sequenceDiagram
    participant HFModel as HuggingFace Model
    participant MimoBridge as MimoBridge
    participant Megatron as Megatron GPTModel
    participant WeightDict as Weight Dictionary
    HFModel->>MimoBridge: Load HF state dict (with input_proj weights)
    MimoBridge->>MimoBridge: maybe_modify_loaded_hf_weight()
    MimoBridge->>MimoBridge: _swap_input_proj_halves(weight)
    MimoBridge->>WeightDict: Return swapped input_proj weights
    MimoBridge->>MimoBridge: mapping_registry (MiMo-specific mappings)
    MimoBridge->>Megatron: Convert HF weights using custom mappings
    Megatron->>MimoBridge: maybe_modify_converted_hf_weight()
    MimoBridge->>MimoBridge: Swap halves again for eh_proj/input_proj
    MimoBridge->>Megatron: Return final converted weights
    Megatron->>Megatron: Load converted weights into model
```
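The flow in the sequence diagram can be sketched in plain Python. The hook name `maybe_modify_loaded_hf_weight` mirrors the diagram, but the function bodies, the key names, and the toy mapping table are illustrative stand-ins, not the real bridge code.

```python
def swap_halves(vec):
    # Stand-in for _swap_input_proj_halves: exchange the two halves of a row.
    half = len(vec) // 2
    return vec[half:] + vec[:half]

def maybe_modify_loaded_hf_weight(name, weight):
    # Only input_proj weights get their halves swapped on the way in.
    return swap_halves(weight) if "input_proj" in name else weight

def convert(hf_state):
    # Rename HF keys to Megatron keys via a (toy) mapping, applying the
    # load-time hook to each weight; unmapped keys pass through unchanged.
    mapping = {"model.input_proj.weight": "mtp.eh_proj.weight"}
    return {
        mapping.get(name, name): maybe_modify_loaded_hf_weight(name, weight)
        for name, weight in hf_state.items()
    }

print(convert({"model.input_proj.weight": [1, 2, 3, 4]}))
# → {'mtp.eh_proj.weight': [3, 4, 1, 2]}
```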
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 1
In `src/megatron/bridge/models/mimo/mimo_bridge.py`, around lines 49-113: the loop in `mapping_registry` currently iterates over `("transformer_layer", "mtp_model_layer")`, but `"mtp_model_layer"` appears unused elsewhere (e.g., in the model instantiation in llava_provider.py). Verify by grepping for the literal `"mtp_model_layer"` and checking the unit test `test_mapping_registry_includes_mtp_paths`. If it is dead, remove `"mtp_model_layer"` from the tuple so `mapping_registry` only builds mappings for `"transformer_layer"` (leaving the `MegatronMappingRegistry`, `AutoMapping`, `QKVMapping`, and `GatedMLPMapping` usage unchanged), and update or remove the test that expects MTP paths so it reflects the real mapping. If you find other live references, keep both prefixes and leave the code and tests as-is.
🧹 Nitpick comments (2)

src/megatron/bridge/models/mimo/mimo_bridge.py (2)

Line 15: Use built-in generics instead of `typing.Dict`. Per coding guidelines, prefer `dict` over `typing.Dict` for Python 3.10+. `Mapping` from `typing` is fine (it's a protocol, not a generic alias), but `Dict` on lines 131 and 133 should be `dict`.

♻️ Suggested fix

```diff
-from typing import Dict, Mapping
+from collections.abc import Mapping
```

Then update usages on lines 131 and 147:

```diff
-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
```

As per coding guidelines: "Use built-in generics (list, dict, tuple) instead of typing equivalents."

Lines 115-118: Add tensor shape validation before splitting. The coding guidelines require validating tensor shapes before weight conversion. `_swap_input_proj_halves` assumes `dim=1` exists and is evenly divisible by 2; a defensive check would prevent silent data corruption if the function were called with an unexpected tensor shape.

🛡️ Suggested shape validation

```diff
 @staticmethod
 def _swap_input_proj_halves(weight: torch.Tensor) -> torch.Tensor:
+    assert weight.ndim >= 2 and weight.shape[1] % 2 == 0, (
+        f"Expected weight with even dim-1, got shape {weight.shape}"
+    )
     first_half, second_half = weight.chunk(2, dim=1)
     return torch.cat((second_half, first_half), dim=1)
```

As per coding guidelines: "Always validate tensor shapes before copying weights in weight conversion."
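A runnable stand-in for the suggested check, using numpy in place of torch so the sketch is self-contained (`weight.chunk(2, dim=1)` corresponds to `np.split(weight, 2, axis=1)`):

```python
import numpy as np

def swap_input_proj_halves(weight):
    # Defensive shape check as suggested above: dim-1 must exist and be even.
    assert weight.ndim >= 2 and weight.shape[1] % 2 == 0, (
        f"Expected weight with even dim-1, got shape {weight.shape}"
    )
    first_half, second_half = np.split(weight, 2, axis=1)
    return np.concatenate((second_half, first_half), axis=1)

w = np.arange(12).reshape(3, 4)
swapped = swap_input_proj_halves(w)
print(swapped[0])  # → [2 3 0 1]

# The swap is an involution: applying it twice restores the original,
# which is why the bridge can apply it both on load and on export.
assert (swap_input_proj_halves(swapped) == w).all()
```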
Force-pushed from 12ff792 to f46ddd4
Force-pushed from 7f66f95 to a755a15
@HollowMan6 No need to write individual model providers anymore; just use the bridge to override where necessary. There is a list of common mappings that should take care of most cases.
MiMo adds MTP (Multi-Token Prediction) layers on top of the Qwen2 architecture, so these models are very helpful for debugging MTP features. Refer to https://github.com/ISEEKYAN/mbridge/blob/main/mbridge/models/mimo.py

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Force-pushed from a755a15 to a8ad6b9
Updated @yaoyu-33

/ok to test a8ad6b9
What does this PR do?
MiMo adds MTP (Multi-Token Prediction) layers on top of Qwen2 architecture, so these models are very helpful for debugging MTP features.
Refer to https://github.com/ISEEKYAN/mbridge/blob/main/mbridge/models/mimo.py
Tested together with VeRL and it looks good: https://github.com/verl-project/verl/blob/main/docs/advance/mtp.md
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information