Commit 41357c8

allow custom router dtype
Signed-off-by: jenchen13 <[email protected]>
1 parent 95f6c25

1 file changed (+17, -5 lines)


modelopt/torch/export/unified_export_megatron.py

@@ -35,6 +35,7 @@
 
 from modelopt import __version__
 from modelopt.torch.utils import import_plugin
+from megatron.core import ModelParallelConfig
 
 from .model_config import (
     KV_CACHE_FP8,
@@ -186,6 +187,7 @@ def __init__(
         export_extra_modules: bool = False,
         dtype=torch.bfloat16,
         trust_remote_code: bool = True,
+        config: ModelParallelConfig | None = None,
     ):
         """Create a GPTModel exporter instance."""
         if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
@@ -196,6 +198,9 @@ def __init__(
         self._hf_config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, trust_remote_code=trust_remote_code
         )
+        if config.moe_router_dtype:
+            if config.moe_router_dtype == "fp32":
+                self.moe_router_dtype = torch.float32
         # If multimodal, extra the text_config
         self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config)
 
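The new config argument threads Megatron's ModelParallelConfig into the exporter so the router export dtype can follow config.moe_router_dtype. The diff handles only the "fp32" string and does not guard against config=None (the default in both entry points). A minimal sketch of a guarded resolution; the helper name, the "fp64" branch, and the fallback are assumptions of this note, not part of the commit:

```python
import torch

def resolve_moe_router_dtype(config, default: torch.dtype) -> torch.dtype:
    # Hypothetical helper: the commit performs only the "fp32" check inline.
    router_dtype = getattr(config, "moe_router_dtype", None) if config else None
    if router_dtype == "fp32":
        return torch.float32
    if router_dtype == "fp64":  # also accepted by recent Megatron-Core configs
        return torch.float64
    return default  # no override: keep the exporter-wide dtype
```
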
@@ -486,9 +491,11 @@ def _custom_mapping_to_lambda(mapping):
                 "pack_name_remapping": self._pack_name_remapping,
                 "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss,
             }
+            print("Mapping: ", mapping)
             func = method_map[mapping.func_name]
             prefix = mapping.target_name_or_prefix
             func_kwargs = mapping.func_kwargs
+            dtype = mapping.dtype
             return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
 
         for arch, mappings in all_mcore_hf_export_mapping.items():
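_custom_mapping_to_lambda binds each mapping entry to its handler, and the new dtype = mapping.dtype line reads a per-mapping dtype from the rule table. A sketch of how such an override can reach the handler through the bound kwargs, using a hypothetical Mapping record in place of the real rule entries:

```python
from dataclasses import dataclass, field
from typing import Any, Callable

import torch

@dataclass
class Mapping:
    # Hypothetical stand-in for the real mapping entries in the rule table.
    func_name: str
    target_name_or_prefix: str
    func_kwargs: dict[str, Any] = field(default_factory=dict)
    dtype: torch.dtype | None = None

def mapping_to_lambda(method_map: dict[str, Callable], mapping: Mapping):
    func = method_map[mapping.func_name]
    prefix = mapping.target_name_or_prefix
    kwargs = dict(mapping.func_kwargs)
    if mapping.dtype is not None:
        kwargs["dtype"] = mapping.dtype  # forward the per-mapping override
    return lambda m, *args: func(m, prefix.format(*args), **kwargs)
```
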
@@ -519,12 +526,16 @@ def _name_remapping(
         prefix: str,
         skip_output_scale: bool = True,
         mapping={},
+        dtype: torch.dtype | None = None
     ):
+        if dtype is None:
+            dtype = self.dtype
+
         if isinstance(module, torch.Tensor):
             self._state_dict[prefix] = module
             return
 
-        name_to_value, qformat, block_size = get_quantized_state(module, self.dtype)
+        name_to_value, qformat, block_size = get_quantized_state(module, dtype)
 
         weight = name_to_value.pop("weight")
         weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
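_name_remapping now takes an optional dtype and falls back to the exporter-wide self.dtype when the caller passes nothing, so existing call sites are unaffected while the router rule can override the dtype handed to get_quantized_state. A minimal illustration of the fallback idiom, with hypothetical names:

```python
import torch

class WeightExporter:
    # Hypothetical class, only to show the per-call dtype fallback pattern.
    def __init__(self, dtype: torch.dtype = torch.bfloat16):
        self.dtype = dtype

    def export(self, weight: torch.Tensor, dtype: torch.dtype | None = None):
        if dtype is None:
            dtype = self.dtype  # per-call override, else instance default
        return weight.to(dtype)

exporter = WeightExporter()
w = torch.randn(4, 4)
assert exporter.export(w).dtype == torch.bfloat16                # default path
assert exporter.export(w, torch.float32).dtype == torch.float32  # override path
```
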
@@ -1098,7 +1109,7 @@ def _get_state_dict(self):
 
             if not isinstance(layer.mlp, IdentityOp):
                 if "MoE" in str(type(layer.mlp)):
-                    self.rules["router"](layer.mlp.router, layer_id)
+                    self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype)
                 if (
                     hasattr(layer.mlp, "shared_experts")
                     and layer.mlp.shared_experts is not None
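With the rule now receiving a dtype keyword, MoE router weights are exported through self.moe_router_dtype while the rest of the checkpoint keeps the main export dtype. Keeping routers in fp32 is a common MoE practice: router logits are often nearly tied, and low-precision rounding can change which expert wins. A small illustrative sketch of that sensitivity (values chosen for the example, not from the commit):

```python
import torch

# Two near-tied router logits: an fp32 argmax picks index 1, but after
# bfloat16 rounding both collapse to the same value and the tie resolves
# to the first index instead.
logits = torch.tensor([0.50390625, 0.50399780])
print(logits.argmax().item())                     # 1 in fp32
print(logits.to(torch.bfloat16).argmax().item())  # 0 after bf16 rounding
```
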
@@ -1138,6 +1149,7 @@ def export_mcore_gpt_to_hf(
     export_extra_modules: bool = False,
     dtype: torch.dtype = torch.float16,
     export_dir: Path | str = tempfile.gettempdir(),
+    config: ModelParallelConfig = None,
 ):
     """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.
 
@@ -1153,7 +1165,7 @@ def export_mcore_gpt_to_hf(
         export_dir: The target export path.
     """
     exporter = GPTModelExporter(
-        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype
+        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype, config=config
     )
     exporter.save_pretrained(export_dir, pretrained_model_name_or_path)
 
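End to end, callers pass the Megatron config through the new keyword so the exporter can see moe_router_dtype. A hedged usage sketch: it assumes model is an already-built (and possibly quantized) Megatron Core GPTModel whose .config is a TransformerConfig (a ModelParallelConfig subclass) carrying the moe_router_dtype field, and the import path simply mirrors the file shown in this diff:

```python
import torch

from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf

def export_with_fp32_router(model, name_or_path: str, export_dir: str):
    # "fp32" is the string the exporter checks for; the field is assumed
    # present on recent Megatron-Core TransformerConfig objects.
    model.config.moe_router_dtype = "fp32"
    export_mcore_gpt_to_hf(
        model,                 # Megatron Core GPTModel to export
        name_or_path,          # HF checkpoint the exporter reads its config from
        dtype=torch.bfloat16,  # main export dtype for non-router weights
        export_dir=export_dir,
        config=model.config,   # new argument introduced by this commit
    )
```
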
@@ -1173,6 +1185,6 @@ def import_mcore_gpt_from_hf(
         dtype: The weights data type to import.
     """
     importer = GPTModelImporter(
-        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype
+        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype,
     )
-    importer._import_state_dict()
+    importer._import_state_dict()
