Commit 5f7158f

fix import and export moe router dtype
Signed-off-by: jenchen13 <[email protected]>
1 parent 8393149 commit 5f7158f

2 files changed (+12 -10 lines)

modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 6 additions & 2 deletions
@@ -77,11 +77,14 @@ def __init__(
         dequantize: bool = True,
         trust_remote_code: bool = True,
         verbose: bool = False,
+        moe_router_dtype: torch.dtype | None = None,
     ):
         """Create a GPTModel importer instance."""
         self._hf_config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, trust_remote_code=trust_remote_code
         )
+        self.moe_router_dtype = moe_router_dtype
+
         pretrained_model_path = Path(pretrained_model_name_or_path)
         if not pretrained_model_path.is_dir():
             if workspace_dir is None:
@@ -118,7 +121,7 @@ def _custom_mapping_to_lambda(mapping):
             func = method_map[mapping.func_name]
             prefix = mapping.target_name_or_prefix
             func_kwargs = mapping.func_kwargs
-            return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
+            return lambda m, *args, **kwargs: func(m, prefix.format(*args), **{**func_kwargs, **kwargs})
 
         for arch, mappings in all_mcore_hf_import_mapping.items():
             all_rules[arch] = {
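
Note on the lambda change above: the rebuilt rule now merges per-call keyword arguments over the mapping's static func_kwargs, which is what lets the router rule below pass dtype= at call time. A minimal standalone sketch of the pattern (the names here are illustrative, not the module's API):

import torch

def make_rule(func, prefix, func_kwargs):
    # Per-call **kwargs are merged over the statically configured func_kwargs,
    # so the call-time keys win.
    return lambda m, *args, **kwargs: func(m, prefix.format(*args), **{**func_kwargs, **kwargs})

def remap(module, name, dtype=None):
    print(f"remap {name} with dtype={dtype}")

rule = make_rule(remap, "model.layers.{}.mlp.gate", {"dtype": None})
rule(None, 0)                       # remap model.layers.0.mlp.gate with dtype=None
rule(None, 3, dtype=torch.float32)  # call-time dtype overrides the static default
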
@@ -140,6 +143,7 @@ def _name_remapping(
         prefix,
         mapping={},
         parallel_config: ParallelConfig | None = None,
+        dtype: torch.dtype | None = None,
     ):
         if isinstance(module, torch.Tensor):
             tensor = self._get_safetensor(prefix, parallel_config=parallel_config)
@@ -523,7 +527,7 @@ def _import_state_dict(self):
                 if not isinstance(layer.mlp, IdentityOp):
                     if "MoE" in str(type(layer.mlp)):
                         layer_pbar.set_description("Importing MoE")
-                        self.rules["router"](layer.mlp.router, layer_id)
+                        self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype)
                         if (
                             hasattr(layer.mlp, "shared_experts")
                             and layer.mlp.shared_experts is not None
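
The hunks above thread self.moe_router_dtype into the "router" rule and add a dtype parameter to _name_remapping; the code that consumes the dtype is outside this diff. A hedged sketch of how such a parameter is typically applied when installing an imported tensor (the helper name and logic below are assumptions, not modelopt's actual implementation):

import torch

def install_imported_weight(param: torch.nn.Parameter,
                            tensor: torch.Tensor,
                            dtype: torch.dtype | None = None) -> None:
    # Assumed behavior: cast to the requested dtype (e.g. torch.float32 for the
    # MoE router) before installing; otherwise keep the checkpoint's dtype.
    if dtype is not None:
        tensor = tensor.to(dtype)
    param.data = tensor

# Keep the router weights in fp32 while the rest of the model stays bf16.
router_weight = torch.nn.Parameter(torch.empty(8, 16, dtype=torch.float32))
imported = torch.randn(8, 16, dtype=torch.bfloat16)
install_imported_weight(router_weight, imported, dtype=torch.float32)
print(router_weight.dtype)  # torch.float32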

modelopt/torch/export/unified_export_megatron.py

Lines changed: 6 additions & 8 deletions
@@ -35,7 +35,6 @@
 
 from modelopt import __version__
 from modelopt.torch.utils import import_plugin
-from megatron.core import ModelParallelConfig
 
 from .model_config import (
     KV_CACHE_FP8,
@@ -187,7 +186,7 @@ def __init__(
         export_extra_modules: bool = False,
         dtype=torch.bfloat16,
         trust_remote_code: bool = True,
-        config: ModelParallelConfig | None = None,
+        moe_router_dtype: torch.dtype | None = None,
     ):
         """Create a GPTModel exporter instance."""
         if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
@@ -198,9 +197,7 @@ def __init__(
         self._hf_config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, trust_remote_code=trust_remote_code
        )
-        if config.moe_router_dtype:
-            if config.moe_router_dtype == "fp32":
-                self.moe_router_dtype = torch.float32
+        self.moe_router_dtype = moe_router_dtype
         # If multimodal, extra the text_config
         self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config)
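
Where the exporter previously read a string field from a Megatron ModelParallelConfig and only handled "fp32", it now takes a torch.dtype directly. A caller that still has the Megatron-style string can translate it first; a hypothetical helper for that conversion (the mapping is an assumption; the removed branch handled only "fp32"):

import torch

# Hypothetical translation from Megatron-style moe_router_dtype strings to the
# torch.dtype now expected by GPTModelExporter / GPTModelImporter.
_ROUTER_DTYPE_MAP = {
    "fp32": torch.float32,  # the only value the removed code handled
    "fp64": torch.float64,  # assumption: extend if other strings are used
}

def resolve_moe_router_dtype(name: str | None) -> torch.dtype | None:
    return _ROUTER_DTYPE_MAP[name] if name else None

print(resolve_moe_router_dtype("fp32"))  # torch.float32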

@@ -1147,7 +1144,7 @@ def export_mcore_gpt_to_hf(
     export_extra_modules: bool = False,
     dtype: torch.dtype = torch.float16,
     export_dir: Path | str = tempfile.gettempdir(),
-    config: ModelParallelConfig = None,
+    moe_router_dtype: torch.dtype | None = None,
 ):
     """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.
 
@@ -1163,7 +1160,7 @@ def export_mcore_gpt_to_hf(
         export_dir: The target export path.
     """
     exporter = GPTModelExporter(
-        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype, config=config
+        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype, moe_router_dtype=moe_router_dtype
     )
     exporter.save_pretrained(export_dir, pretrained_model_name_or_path)
 
@@ -1173,6 +1170,7 @@ def import_mcore_gpt_from_hf(
     pretrained_model_path: str,
     workspace_dir: str | None = None,
     dtype: torch.dtype = torch.float16,
+    moe_router_dtype: torch.dtype | None = None,
 ):
     """Import GPTModel state_dict from supported HuggingFace pretrained model path.
 
@@ -1183,6 +1181,6 @@ def import_mcore_gpt_from_hf(
         dtype: The weights data type to import.
     """
     importer = GPTModelImporter(
-        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype,
+        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype, moe_router_dtype=moe_router_dtype
     )
     importer._import_state_dict()
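
With both entry points now accepting moe_router_dtype directly, keeping the MoE router in fp32 during an HF import/export round trip looks roughly like the sketch below. The model object, checkpoint path, and any leading positional arguments not visible in these hunks are placeholders or assumptions; the keyword names match the signatures above.

import torch
from modelopt.torch.export.unified_export_megatron import (
    export_mcore_gpt_to_hf,
    import_mcore_gpt_from_hf,
)

model = build_mcore_gpt_model()    # placeholder: an existing Megatron Core GPTModel
hf_path = "path/to/hf_checkpoint"  # placeholder HF model name or path

# Import HF weights, loading the MoE router weights in fp32.
import_mcore_gpt_from_hf(
    model,
    hf_path,
    dtype=torch.bfloat16,
    moe_router_dtype=torch.float32,
)

# Export back to a unified HF-style checkpoint, preserving the fp32 router.
export_mcore_gpt_to_hf(
    model,
    hf_path,
    dtype=torch.bfloat16,
    export_dir="exported_model",
    moe_router_dtype=torch.float32,
)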
