|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from typing import Mapping |
| 16 | + |
| 17 | +import torch |
| 18 | +from megatron.core.models.gpt.gpt_model import GPTModel |
| 19 | + |
| 20 | +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry |
| 21 | +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask |
| 22 | +from megatron.bridge.models.conversion.param_mapping import ( |
| 23 | + AutoMapping, |
| 24 | + GatedMLPMapping, |
| 25 | + QKVMapping, |
| 26 | +) |
| 27 | +from megatron.bridge.models.qwen.qwen2_bridge import Qwen2Bridge |
| 28 | + |
| 29 | + |
| 30 | +@MegatronModelBridge.register_bridge(source="MiMoForCausalLM", target=GPTModel, model_type="mimo") |
| 31 | +class MimoBridge(Qwen2Bridge): |
| 32 | + """Megatron Bridge for MiMo Causal LM.""" |
| 33 | + |
| 34 | + def provider_bridge(self, hf_pretrained): |
| 35 | + provider = super().provider_bridge(hf_pretrained) |
| 36 | + hf_config = hf_pretrained.config |
| 37 | + |
| 38 | + # MiMo follows Qwen2 attention behavior and adds MTP on top. |
| 39 | + provider.qk_layernorm = False |
| 40 | + provider.add_qkv_bias = True |
| 41 | + |
| 42 | + num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0) |
| 43 | + if num_mtp_layers > 0: |
| 44 | + provider.mtp_num_layers = num_mtp_layers |
| 45 | + provider.mtp_loss_scaling_factor = 0.1 |
| 46 | + |
| 47 | + return provider |
| 48 | + |
| 49 | + def mapping_registry(self) -> MegatronMappingRegistry: |
| 50 | + mapping_list = list(super().mapping_registry().mappings) |
| 51 | + |
| 52 | + mapping_list.extend( |
| 53 | + [ |
| 54 | + AutoMapping( |
| 55 | + megatron_param="mtp.layers.*.enorm.weight", |
| 56 | + hf_param="model.mtp_layers.*.token_layernorm.weight", |
| 57 | + ), |
| 58 | + AutoMapping( |
| 59 | + megatron_param="mtp.layers.*.hnorm.weight", |
| 60 | + hf_param="model.mtp_layers.*.hidden_layernorm.weight", |
| 61 | + ), |
| 62 | + AutoMapping( |
| 63 | + megatron_param="mtp.layers.*.eh_proj.weight", |
| 64 | + hf_param="model.mtp_layers.*.input_proj.weight", |
| 65 | + ), |
| 66 | + AutoMapping( |
| 67 | + megatron_param="mtp.layers.*.final_layernorm.weight", |
| 68 | + hf_param="model.mtp_layers.*.final_layernorm.weight", |
| 69 | + ), |
| 70 | + ] |
| 71 | + ) |
| 72 | + |
| 73 | + # Support both naming conventions: Megatron-Core may expose MTP layers as |
| 74 | + # either "transformer_layer" or "mtp_model_layer" depending on configuration |
| 75 | + for layer_prefix in ("transformer_layer", "mtp_model_layer"): |
| 76 | + layer_path = f"mtp.layers.*.{layer_prefix}" |
| 77 | + mapping_list.extend( |
| 78 | + [ |
| 79 | + AutoMapping( |
| 80 | + megatron_param=f"{layer_path}.self_attention.linear_qkv.layer_norm_weight", |
| 81 | + hf_param="model.mtp_layers.*.input_layernorm.weight", |
| 82 | + ), |
| 83 | + AutoMapping( |
| 84 | + megatron_param=f"{layer_path}.mlp.linear_fc1.layer_norm_weight", |
| 85 | + hf_param="model.mtp_layers.*.post_attention_layernorm.weight", |
| 86 | + ), |
| 87 | + AutoMapping( |
| 88 | + megatron_param=f"{layer_path}.self_attention.linear_proj.weight", |
| 89 | + hf_param="model.mtp_layers.*.self_attn.o_proj.weight", |
| 90 | + ), |
| 91 | + AutoMapping( |
| 92 | + megatron_param=f"{layer_path}.mlp.linear_fc2.weight", |
| 93 | + hf_param="model.mtp_layers.*.mlp.down_proj.weight", |
| 94 | + ), |
| 95 | + QKVMapping( |
| 96 | + megatron_param=f"{layer_path}.self_attention.linear_qkv.weight", |
| 97 | + q="model.mtp_layers.*.self_attn.q_proj.weight", |
| 98 | + k="model.mtp_layers.*.self_attn.k_proj.weight", |
| 99 | + v="model.mtp_layers.*.self_attn.v_proj.weight", |
| 100 | + ), |
| 101 | + QKVMapping( |
| 102 | + megatron_param=f"{layer_path}.self_attention.linear_qkv.bias", |
| 103 | + q="model.mtp_layers.*.self_attn.q_proj.bias", |
| 104 | + k="model.mtp_layers.*.self_attn.k_proj.bias", |
| 105 | + v="model.mtp_layers.*.self_attn.v_proj.bias", |
| 106 | + ), |
| 107 | + GatedMLPMapping( |
| 108 | + megatron_param=f"{layer_path}.mlp.linear_fc1.weight", |
| 109 | + gate="model.mtp_layers.*.mlp.gate_proj.weight", |
| 110 | + up="model.mtp_layers.*.mlp.up_proj.weight", |
| 111 | + ), |
| 112 | + ] |
| 113 | + ) |
| 114 | + |
| 115 | + return MegatronMappingRegistry(*mapping_list) |
| 116 | + |
| 117 | + @staticmethod |
| 118 | + def _swap_input_proj_halves(weight: torch.Tensor) -> torch.Tensor: |
| 119 | + if weight.ndim < 2: |
| 120 | + raise ValueError( |
| 121 | + f"Expected tensor with at least 2 dimensions for input_proj weight, got shape {weight.shape}" |
| 122 | + ) |
| 123 | + if weight.shape[1] % 2 != 0: |
| 124 | + raise ValueError(f"Expected even dimension at dim=1 for input_proj weight, got shape {weight.shape}") |
| 125 | + first_half, second_half = weight.chunk(2, dim=1) |
| 126 | + return torch.cat((second_half, first_half), dim=1) |
| 127 | + |
| 128 | + def maybe_modify_loaded_hf_weight( |
| 129 | + self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor] |
| 130 | + ) -> torch.Tensor: |
| 131 | + hf_weights = super().maybe_modify_loaded_hf_weight(hf_param, hf_state_dict) |
| 132 | + if isinstance(hf_param, str) and hf_param.endswith(".input_proj.weight"): |
| 133 | + return self._swap_input_proj_halves(hf_weights) |
| 134 | + return hf_weights |
| 135 | + |
| 136 | + def maybe_modify_converted_hf_weight( |
| 137 | + self, |
| 138 | + task: WeightConversionTask, |
| 139 | + converted_weights_dict: dict[str, torch.Tensor], |
| 140 | + hf_state_dict: Mapping[str, torch.Tensor], |
| 141 | + ) -> dict[str, torch.Tensor]: |
| 142 | + converted_weights_dict = super().maybe_modify_converted_hf_weight( |
| 143 | + task, |
| 144 | + converted_weights_dict, |
| 145 | + hf_state_dict, |
| 146 | + ) |
| 147 | + |
| 148 | + if not task.global_param_name.endswith(".eh_proj.weight"): |
| 149 | + return converted_weights_dict |
| 150 | + |
| 151 | + for hf_name, weight in list(converted_weights_dict.items()): |
| 152 | + if hf_name.endswith(".input_proj.weight"): |
| 153 | + converted_weights_dict[hf_name] = self._swap_input_proj_halves(weight) |
| 154 | + |
| 155 | + return converted_weights_dict |
0 commit comments