Add MiMo dense MTP models bridge support #2387

Merged
yaoyu-33 merged 1 commit into NVIDIA-NeMo:main from HollowMan6:mimo_dense
Feb 25, 2026

Conversation

Contributor

@HollowMan6 HollowMan6 commented Feb 14, 2026

What does this PR do ?

MiMo adds MTP (Multi-Token Prediction) layers on top of Qwen2 architecture, so these models are very helpful for debugging MTP features.

Refer to https://github.com/ISEEKYAN/mbridge/blob/main/mbridge/models/mimo.py

Tested together with VeRL; results look good: https://github.com/verl-project/verl/blob/main/docs/advance/mtp.md
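For readers unfamiliar with MTP, a minimal PyTorch sketch of the idea can help. This is illustrative only, not the MiMo implementation, and all names in it are hypothetical: an MTP head concatenates the backbone's hidden state at each position with the embedding of the shifted target token, projects the pair back to model width, and runs one extra transformer block so an additional future token can be predicted in the same forward pass.

```python
import torch
import torch.nn as nn

class ToyMTPHead(nn.Module):
    """Illustrative MTP head sketch; not the actual MiMo module."""

    def __init__(self, hidden: int, heads: int = 4):
        super().__init__()
        # input_proj maps the concatenated [hidden_state ; token_embed]
        # pair (width 2 * hidden) back down to the model width.
        self.input_proj = nn.Linear(2 * hidden, hidden, bias=False)
        self.block = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=heads, batch_first=True
        )

    def forward(self, hidden_states: torch.Tensor,
                next_token_embeds: torch.Tensor) -> torch.Tensor:
        # Pair each position's hidden state with the next token's embedding.
        x = torch.cat([hidden_states, next_token_embeds], dim=-1)
        return self.block(self.input_proj(x))

head = ToyMTPHead(hidden=64)
h = torch.randn(2, 8, 64)   # [batch, seq, hidden] from the backbone
e = torch.randn(2, 8, 64)   # embeddings of the shifted target tokens
out = head(h, e)
print(tuple(out.shape))     # (2, 8, 64)
```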

Changelog

  • Add MiMo dense MTP models bridge support

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex, etc.)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for MiMo 7B models with multiple pre-configured variants including Base, Supervised Fine-Tune, and Reinforcement Learning configurations.
    • Extended context support available with the RL-0530 variant, enabling processing of 65K token sequences.
    • Integrated model bridge enabling seamless MiMo model conversion and deployment.
  • Tests

    • Added comprehensive unit tests validating MiMo model configurations and bridge integration.

Copilot AI review requested due to automatic review settings February 14, 2026 16:20

copy-pr-bot bot commented Feb 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Copilot AI left a comment

Pull request overview

This PR adds support for MiMo (Multi-Token Prediction) dense models from Xiaomi, which are built on top of the Qwen2 architecture with additional MTP layers. The implementation includes a bridge for weight conversion, provider configurations for various MiMo 7B model variants, and comprehensive test coverage.

Changes:

  • Added MimoBridge class extending Qwen2Bridge with MTP-specific weight mappings and transformations
  • Added MiMoModelProvider7B base class and 5 variant providers for different MiMo 7B models
  • Added comprehensive unit tests for both bridge functionality and provider configurations

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated no comments.

Summary per file:

  • src/megatron/bridge/models/mimo/mimo_bridge.py — Implements the bridge for MiMo models with MTP layer mappings and input-projection weight transformations
  • src/megatron/bridge/models/mimo/mimo_causal_provider.py — Defines provider configurations for the MiMo 7B model family with MTP parameters
  • src/megatron/bridge/models/mimo/__init__.py — Exports the new MiMo provider classes
  • src/megatron/bridge/models/__init__.py — Exports the new MiMo providers and bridge at the top level
  • tests/unit_tests/models/mimo/test_mimo_bridge.py — Tests bridge registration, provider mapping, MTP layer mappings, and weight transformations
  • tests/unit_tests/models/mimo/test_mimo_causal_provider.py — Tests provider default configurations and variant-specific settings


Contributor

coderabbitai bot commented Feb 14, 2026

📝 Walkthrough

Walkthrough

This PR adds MiMo (Xiaomi) model family support to Megatron Bridge by introducing MiMo 7B provider configurations (base, SFT, RL variants) and a custom MimoBridge class extending Qwen2Bridge with MiMo-specific weight mappings and parameter transformations.

Changes

Cohort / File(s) Summary
MiMo Provider Configuration
src/megatron/bridge/models/mimo/mimo_causal_provider.py
New module defining six dataclass variants of MiMo 7B model providers (including Base, SFT, RL, RL-Zero, and RL-0530) with specific hyperparameters including hidden size (4096), num_layers (32), attention heads (32), MTP settings, and dtype configurations. The RL-0530 variant overrides seq_length to 65536.
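The provider pattern above can be sketched with plain dataclasses. The class and field names here are hypothetical; the values for hidden size (4096), layer count (32), attention heads (32), and the 65536 seq_length override come from the summary, while the MTP depth and default context length are assumptions for illustration.

```python
from dataclasses import dataclass

@dataclass
class MiMoProvider7BSketch:
    # Hyperparameters stated in the PR summary.
    hidden_size: int = 4096
    num_layers: int = 32
    num_attention_heads: int = 32
    # Assumed values, for illustration only.
    mtp_num_layers: int = 1
    seq_length: int = 32768

@dataclass
class MiMoProvider7BRL0530Sketch(MiMoProvider7BSketch):
    # The RL-0530 variant extends the context window to 65K tokens;
    # everything else is inherited from the base provider.
    seq_length: int = 65536

cfg = MiMoProvider7BRL0530Sketch()
print(cfg.hidden_size, cfg.seq_length)  # 4096 65536
```

Variant-specific providers only override the fields that differ, which keeps the family of six configurations small and easy to audit.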
MiMo Bridge Implementation
src/megatron/bridge/models/mimo/mimo_bridge.py
New MimoBridge class extending Qwen2Bridge with custom provider_bridge override (disables qk_layernorm, enables qkv_bias, configures MTP layers), augmented mapping_registry with MiMo-specific layer mappings, and weight transformation logic including _swap_input_proj_halves helper to swap weight tensor halves for MiMo's dual-channel projections.
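A standalone sketch of the half-swapping transform described above: split the projection weight along dim 1 and swap the two halves, reordering MiMo's dual-channel input projection into the target layout. The function name mirrors the helper mentioned in the summary; the shape check is an added defensive assumption.

```python
import torch

def swap_input_proj_halves(weight: torch.Tensor) -> torch.Tensor:
    # Guard against tensors whose second dimension cannot be split evenly.
    assert weight.ndim >= 2 and weight.shape[1] % 2 == 0, (
        f"Expected weight with even dim-1, got shape {weight.shape}"
    )
    first_half, second_half = weight.chunk(2, dim=1)
    return torch.cat((second_half, first_half), dim=1)

w = torch.arange(8.0).reshape(2, 4)    # [[0,1,2,3],[4,5,6,7]]
swapped = swap_input_proj_halves(w)    # [[2,3,0,1],[6,7,4,5]]
# The swap is its own inverse, which is why the same helper can be applied
# both when loading HF weights and when exporting converted weights.
print(torch.equal(swap_input_proj_halves(swapped), w))  # True
```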
Module Exports
src/megatron/bridge/models/__init__.py, src/megatron/bridge/models/mimo/__init__.py
Added public API exports for six MiMoModelProvider7* classes and MimoBridge class to enable direct imports from megatron.bridge.models and megatron.bridge.models.mimo.
Unit Tests
tests/unit_tests/models/mimo/test_mimo_causal_provider.py, tests/unit_tests/models/mimo/test_mimo_bridge.py
New test suites validating MiMo provider default hyperparameters across all six variants and MimoBridge integration including registration, provider mapping attributes, mapping registry entries, and weight swapping logic on HF load/conversion.

Sequence Diagram

sequenceDiagram
    participant HFModel as HuggingFace Model
    participant MimoBridge as MimoBridge
    participant Megatron as Megatron GPTModel
    participant WeightDict as Weight Dictionary

    HFModel->>MimoBridge: Load HF state dict (with input_proj weights)
    MimoBridge->>MimoBridge: maybe_modify_loaded_hf_weight()
    MimoBridge->>MimoBridge: _swap_input_proj_halves(weight)
    MimoBridge->>WeightDict: Return swapped input_proj weights
    
    MimoBridge->>MimoBridge: mapping_registry (MiMo-specific mappings)
    MimoBridge->>Megatron: Convert HF weights using custom mappings
    
    Megatron->>MimoBridge: maybe_modify_converted_hf_weight()
    MimoBridge->>MimoBridge: Swap halves again for eh_proj/input_proj
    MimoBridge->>Megatron: Return final converted weights
    
    Megatron->>Megatron: Load converted weights into model

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Suggested labels

Run CICD

Suggested reviewers

  • yaoyu-33
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ — Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Test Results For Major Changes ⚠️ — The PR adds major new functionality (MiMo dense MTP models) but the description lacks test execution results, performance metrics, or regression-testing documentation. Resolution: include actual test results, performance metrics if applicable, and regression-testing confirmation in the PR description.

✅ Passed checks (3 passed)
  • Title check ✅ — The title 'Add MiMo dense MTP models bridge support' clearly and specifically describes the main change.
  • Merge Conflict Detection ✅ — No merge conflicts detected when merging into main.
  • Description Check ✅ — Check skipped; CodeRabbit's high-level summary is enabled.


coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/mimo/mimo_bridge.py`:
- Around line 49-113: The loop in mapping_registry currently iterates over
("transformer_layer", "mtp_model_layer") but "mtp_model_layer" appears unused
elsewhere (e.g., model instantiation in llava_provider.py); verify by grepping
for the literal "mtp_model_layer" and the unit test
test_mapping_registry_includes_mtp_paths; if it is dead remove "mtp_model_layer"
from the tuple so mapping_registry only builds mappings for "transformer_layer"
(leave MegatronMappingRegistry, AutoMapping, QKVMapping, GatedMLPMapping usage
unchanged), and update or remove the test that expects mtp paths to reflect the
real mapping; if you find other live references keep both prefixes and leave the
code/tests as-is.
🧹 Nitpick comments (2)
src/megatron/bridge/models/mimo/mimo_bridge.py (2)

15-15: Use built-in generics instead of typing.Dict.

Per coding guidelines, prefer dict over typing.Dict for Python 3.10+. Mapping from typing is fine (it's a protocol, not a generic alias), but Dict on lines 131 and 133 should be dict.

♻️ Suggested fix
-from typing import Dict, Mapping
+from collections.abc import Mapping

Then update usages on lines 131 and 147:

-    ) -> Dict[str, torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:

As per coding guidelines: "Use built-in generics (list, dict, tuple) instead of typing equivalents."


115-118: Add tensor shape validation before splitting.

The coding guidelines require validating tensor shapes before weight conversion. _swap_input_proj_halves assumes dim=1 exists and is evenly divisible by 2. A defensive check would prevent silent data corruption if called with an unexpected tensor shape.

🛡️ Suggested shape validation
 @staticmethod
 def _swap_input_proj_halves(weight: torch.Tensor) -> torch.Tensor:
+    assert weight.ndim >= 2 and weight.shape[1] % 2 == 0, (
+        f"Expected weight with even dim-1, got shape {weight.shape}"
+    )
     first_half, second_half = weight.chunk(2, dim=1)
     return torch.cat((second_half, first_half), dim=1)

As per coding guidelines: "Always validate tensor shapes before copying weights in weight conversion."

@yaoyu-33
Contributor

@HollowMan6
Can you refer to our latest refactor?
#2052

There is no need to create individual model providers anymore; just use the bridge to override where necessary. There is a list of common mappings that should take care of most cases.

MiMo adds MTP (Multi-Token Prediction) layers on top of Qwen2 architecture,
so these models are very helpful for debugging MTP features.

Refer to https://github.com/ISEEKYAN/mbridge/blob/main/mbridge/models/mimo.py

Signed-off-by: Hollow Man <hollowman@opensuse.org>
@HollowMan6
Contributor Author

Updated @yaoyu-33

@yaoyu-33
Contributor

/ok to test a8ad6b9


3 participants