Skip to content

Commit a8ad6b9

Browse files
committed
Add support for MiMo dense MTP models
MiMo adds MTP (Multi-Token Prediction) layers on top of the Qwen2 architecture, so these models are very helpful for debugging MTP features. Refer to https://github.com/ISEEKYAN/mbridge/blob/main/mbridge/models/mimo.py Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent f91542b commit a8ad6b9

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

src/megatron/bridge/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
LlamaNemotronHeterogeneousProvider,
107107
)
108108
from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider
109+
from megatron.bridge.models.mimo.mimo_bridge import MimoBridge
109110
from megatron.bridge.models.ministral3 import (
110111
Ministral3Bridge,
111112
Ministral3Model,
@@ -312,6 +313,7 @@
312313
"NemotronNano12Bv2Provider",
313314
"Nemotron3NanoProvider",
314315
"MambaModelProvider",
316+
"MimoBridge",
315317
# Nemotron Models
316318
"NemotronBridge",
317319
"NemotronModelProvider",
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Mapping
16+
17+
import torch
18+
from megatron.core.models.gpt.gpt_model import GPTModel
19+
20+
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
21+
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask
22+
from megatron.bridge.models.conversion.param_mapping import (
23+
AutoMapping,
24+
GatedMLPMapping,
25+
QKVMapping,
26+
)
27+
from megatron.bridge.models.qwen.qwen2_bridge import Qwen2Bridge
28+
29+
30+
@MegatronModelBridge.register_bridge(source="MiMoForCausalLM", target=GPTModel, model_type="mimo")
class MimoBridge(Qwen2Bridge):
    """Megatron Bridge for MiMo Causal LM.

    MiMo stacks Multi-Token Prediction (MTP) layers on top of a Qwen2
    backbone, so this bridge reuses the Qwen2 conversion logic and adds
    the MTP-specific parameter mappings plus the ``input_proj`` weight
    transform needed in both conversion directions.
    """

    def provider_bridge(self, hf_pretrained):
        """Build the model provider from the HF config, enabling MTP when present.

        Args:
            hf_pretrained: HF pretrained model wrapper exposing ``.config``.

        Returns:
            The provider produced by the Qwen2 bridge, with MiMo-specific
            attention flags and (if configured) MTP settings applied.
        """
        provider = super().provider_bridge(hf_pretrained)

        # MiMo follows Qwen2 attention behavior and adds MTP on top.
        provider.qk_layernorm = False
        provider.add_qkv_bias = True

        mtp_layer_count = getattr(hf_pretrained.config, "num_nextn_predict_layers", 0)
        if mtp_layer_count > 0:
            provider.mtp_num_layers = mtp_layer_count
            # NOTE(review): 0.1 mirrors the reference mbridge implementation —
            # confirm against the published MiMo training recipe.
            provider.mtp_loss_scaling_factor = 0.1

        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        """Extend the Qwen2 mapping registry with MTP-layer parameter mappings."""
        mappings = list(super().mapping_registry().mappings)

        # Norms and the embedding->hidden projection wrapping each MTP block.
        for megatron_suffix, hf_suffix in (
            ("enorm.weight", "token_layernorm.weight"),
            ("hnorm.weight", "hidden_layernorm.weight"),
            ("eh_proj.weight", "input_proj.weight"),
            ("final_layernorm.weight", "final_layernorm.weight"),
        ):
            mappings.append(
                AutoMapping(
                    megatron_param=f"mtp.layers.*.{megatron_suffix}",
                    hf_param=f"model.mtp_layers.*.{hf_suffix}",
                )
            )

        # Support both naming conventions: Megatron-Core may expose MTP layers
        # as either "transformer_layer" or "mtp_model_layer" depending on
        # configuration, so register every mapping under both prefixes.
        for inner_name in ("transformer_layer", "mtp_model_layer"):
            prefix = f"mtp.layers.*.{inner_name}"
            mappings.append(
                AutoMapping(
                    megatron_param=f"{prefix}.self_attention.linear_qkv.layer_norm_weight",
                    hf_param="model.mtp_layers.*.input_layernorm.weight",
                )
            )
            mappings.append(
                AutoMapping(
                    megatron_param=f"{prefix}.mlp.linear_fc1.layer_norm_weight",
                    hf_param="model.mtp_layers.*.post_attention_layernorm.weight",
                )
            )
            mappings.append(
                AutoMapping(
                    megatron_param=f"{prefix}.self_attention.linear_proj.weight",
                    hf_param="model.mtp_layers.*.self_attn.o_proj.weight",
                )
            )
            mappings.append(
                AutoMapping(
                    megatron_param=f"{prefix}.mlp.linear_fc2.weight",
                    hf_param="model.mtp_layers.*.mlp.down_proj.weight",
                )
            )
            mappings.append(
                QKVMapping(
                    megatron_param=f"{prefix}.self_attention.linear_qkv.weight",
                    q="model.mtp_layers.*.self_attn.q_proj.weight",
                    k="model.mtp_layers.*.self_attn.k_proj.weight",
                    v="model.mtp_layers.*.self_attn.v_proj.weight",
                )
            )
            mappings.append(
                QKVMapping(
                    megatron_param=f"{prefix}.self_attention.linear_qkv.bias",
                    q="model.mtp_layers.*.self_attn.q_proj.bias",
                    k="model.mtp_layers.*.self_attn.k_proj.bias",
                    v="model.mtp_layers.*.self_attn.v_proj.bias",
                )
            )
            mappings.append(
                GatedMLPMapping(
                    megatron_param=f"{prefix}.mlp.linear_fc1.weight",
                    gate="model.mtp_layers.*.mlp.gate_proj.weight",
                    up="model.mtp_layers.*.mlp.up_proj.weight",
                )
            )

        return MegatronMappingRegistry(*mappings)

    @staticmethod
    def _swap_input_proj_halves(weight: torch.Tensor) -> torch.Tensor:
        """Exchange the two halves of ``weight`` along dim=1.

        Presumably the HF and Megatron ``input_proj`` layers concatenate
        their two inputs in opposite orders — TODO confirm against the
        reference implementation.

        Raises:
            ValueError: if ``weight`` has fewer than 2 dims or an odd dim-1 size.
        """
        if weight.ndim < 2:
            raise ValueError(
                f"Expected tensor with at least 2 dimensions for input_proj weight, got shape {weight.shape}"
            )
        width = weight.shape[1]
        if width % 2 != 0:
            raise ValueError(f"Expected even dimension at dim=1 for input_proj weight, got shape {weight.shape}")
        # Rotating by half the width along dim=1 swaps the two halves.
        return torch.roll(weight, shifts=width // 2, dims=1)

    def maybe_modify_loaded_hf_weight(
        self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor]
    ) -> torch.Tensor:
        """Swap the halves of MTP ``input_proj`` weights when loading from HF."""
        weight = super().maybe_modify_loaded_hf_weight(hf_param, hf_state_dict)
        needs_swap = isinstance(hf_param, str) and hf_param.endswith(".input_proj.weight")
        return self._swap_input_proj_halves(weight) if needs_swap else weight

    def maybe_modify_converted_hf_weight(
        self,
        task: WeightConversionTask,
        converted_weights_dict: dict[str, torch.Tensor],
        hf_state_dict: Mapping[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Swap the halves of MTP ``input_proj`` weights when exporting to HF."""
        result = super().maybe_modify_converted_hf_weight(
            task,
            converted_weights_dict,
            hf_state_dict,
        )

        # Only the eh_proj conversion task produces input_proj outputs.
        if not task.global_param_name.endswith(".eh_proj.weight"):
            return result

        for hf_name in list(result):
            if hf_name.endswith(".input_proj.weight"):
                result[hf_name] = self._swap_input_proj_halves(result[hf_name])

        return result
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import Mock
16+
17+
import pytest
18+
import torch
19+
from transformers import GenerationConfig
20+
21+
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask
22+
from megatron.bridge.models.gpt_provider import GPTModelProvider
23+
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
24+
from megatron.bridge.models.mimo.mimo_bridge import MimoBridge
25+
26+
27+
class TestMimoBridge:
    """Test cases for MimoBridge."""

    @pytest.fixture
    def mimo_config(self):
        """HF-style config dict carrying the MiMo checkpoint fields."""
        return {
            "architectures": ["MiMoForCausalLM"],
            "attention_bias": True,
            "attention_dropout": 0.0,
            "bos_token_id": 151643,
            "eos_token_id": 151645,
            "hidden_size": 4096,
            "initializer_range": 0.02,
            "intermediate_size": 11008,
            "max_position_embeddings": 32768,
            "model_type": "mimo",
            "num_attention_heads": 32,
            "num_hidden_layers": 36,
            "num_key_value_heads": 8,
            "num_nextn_predict_layers": 1,
            "rms_norm_eps": 1e-05,
            "rope_theta": 640000.0,
            "tie_word_embeddings": False,
            "torch_dtype": "bfloat16",
            "vocab_size": 151680,
        }

    @pytest.fixture
    def mock_pretrained_mimo(self, mimo_config):
        """Mock PreTrainedCausalLM whose config exposes the MiMo fields."""
        config_mock = Mock(spec=list(mimo_config.keys()), **mimo_config)

        pretrained = Mock(spec=PreTrainedCausalLM)
        pretrained.config = config_mock
        pretrained.generation_config = Mock(spec=GenerationConfig)
        return pretrained

    def test_registration(self):
        """MimoBridge must be a MegatronModelBridge subclass."""
        assert issubclass(MimoBridge, MegatronModelBridge)

    def test_provider_bridge_maps_mtp_config(self, mock_pretrained_mimo):
        """provider_bridge copies HF config fields and enables MTP."""
        cfg = mock_pretrained_mimo.config
        provider = MimoBridge().provider_bridge(mock_pretrained_mimo)

        assert isinstance(provider, GPTModelProvider)
        assert provider.hidden_size == cfg.hidden_size
        assert provider.num_attention_heads == cfg.num_attention_heads
        assert provider.ffn_hidden_size == cfg.intermediate_size
        assert provider.vocab_size == cfg.vocab_size
        assert provider.qk_layernorm is False
        assert provider.add_qkv_bias is True
        assert provider.mtp_num_layers == cfg.num_nextn_predict_layers
        assert provider.mtp_loss_scaling_factor == 0.1
        assert provider.bf16 is True
        assert provider.params_dtype == torch.bfloat16

    def test_mapping_registry_includes_mtp_paths(self):
        """The registry resolves MTP params under both inner-layer names."""
        registry = MimoBridge().mapping_registry()

        eh_proj = registry.megatron_to_hf_lookup("mtp.layers.0.eh_proj.weight")
        assert eh_proj is not None
        assert eh_proj.hf_param == "model.mtp_layers.0.input_proj.weight"

        for inner_name in ("transformer_layer", "mtp_model_layer"):
            qkv = registry.megatron_to_hf_lookup(
                f"mtp.layers.0.{inner_name}.self_attention.linear_qkv.weight"
            )
            assert qkv is not None
            assert qkv.hf_param["q"] == "model.mtp_layers.0.self_attn.q_proj.weight"

    def test_mtp_input_proj_swap_on_hf_load(self):
        """Loading an input_proj weight from HF swaps its dim-1 halves."""
        bridge = MimoBridge()
        original = torch.arange(24, dtype=torch.float32).reshape(3, 8)
        hf_key = "model.mtp_layers.0.input_proj.weight"

        swapped = bridge.maybe_modify_loaded_hf_weight(hf_key, {hf_key: original})

        left, right = original.chunk(2, dim=1)
        assert torch.equal(swapped, torch.cat((right, left), dim=1))

    def test_mtp_input_proj_swap_on_hf_export(self):
        """Exporting eh_proj to HF swaps the input_proj dim-1 halves back."""
        bridge = MimoBridge()
        original = torch.arange(24, dtype=torch.float32).reshape(3, 8)

        task = WeightConversionTask(
            param_name="mtp.layers.0.eh_proj.weight",
            global_param_name="mtp.layers.0.eh_proj.weight",
            mapping=Mock(),
        )
        hf_key = "model.mtp_layers.0.input_proj.weight"

        exported = bridge.maybe_modify_converted_hf_weight(task, {hf_key: original}, {})

        left, right = original.chunk(2, dim=1)
        assert torch.equal(exported[hf_key], torch.cat((right, left), dim=1))

0 commit comments

Comments
 (0)