3535
3636from modelopt import __version__
3737from modelopt .torch .utils import import_plugin
38- from megatron .core import ModelParallelConfig
3938
4039from .model_config import (
4140 KV_CACHE_FP8 ,
@@ -187,7 +186,7 @@ def __init__(
187186 export_extra_modules : bool = False ,
188187 dtype = torch .bfloat16 ,
189188 trust_remote_code : bool = True ,
190- config : ModelParallelConfig | None = None ,
189+ moe_router_dtype : torch . dtype | None = None ,
191190 ):
192191 """Create a GPTModel exporter instance."""
193192 if not isinstance (model , (GPTModel , MambaModel , LLaVAModel )):
@@ -198,9 +197,7 @@ def __init__(
198197 self ._hf_config = transformers .AutoConfig .from_pretrained (
199198 pretrained_model_name_or_path , trust_remote_code = trust_remote_code
200199 )
201- if config .moe_router_dtype :
202- if config .moe_router_dtype == "fp32" :
203- self .moe_router_dtype = torch .float32
200+ self .moe_router_dtype = moe_router_dtype
204201 # If multimodal, extra the text_config
205202 self ._hf_text_config = getattr (self ._hf_config , "text_config" , self ._hf_config )
206203
@@ -1147,7 +1144,7 @@ def export_mcore_gpt_to_hf(
11471144 export_extra_modules : bool = False ,
11481145 dtype : torch .dtype = torch .float16 ,
11491146 export_dir : Path | str = tempfile .gettempdir (),
1150- config : ModelParallelConfig = None ,
1147+ moe_router_dtype : torch . dtype | None = None ,
11511148):
11521149 """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.
11531150
@@ -1163,7 +1160,7 @@ def export_mcore_gpt_to_hf(
11631160 export_dir: The target export path.
11641161 """
11651162 exporter = GPTModelExporter (
1166- model , pretrained_model_name_or_path , export_extra_modules = export_extra_modules , dtype = dtype , config = config
1163+ model , pretrained_model_name_or_path , export_extra_modules = export_extra_modules , dtype = dtype , moe_router_dtype = moe_router_dtype
11671164 )
11681165 exporter .save_pretrained (export_dir , pretrained_model_name_or_path )
11691166
@@ -1173,6 +1170,7 @@ def import_mcore_gpt_from_hf(
11731170 pretrained_model_path : str ,
11741171 workspace_dir : str | None = None ,
11751172 dtype : torch .dtype = torch .float16 ,
1173+ moe_router_dtype : torch .dtype | None = None ,
11761174):
11771175 """Import GPTModel state_dict from supported HuggingFace pretrained model path.
11781176
@@ -1183,6 +1181,6 @@ def import_mcore_gpt_from_hf(
11831181 dtype: The weights data type to import.
11841182 """
11851183 importer = GPTModelImporter (
1186- model , pretrained_model_path , workspace_dir = workspace_dir , dtype = dtype ,
1184+ model , pretrained_model_path , workspace_dir = workspace_dir , dtype = dtype , moe_router_dtype = moe_router_dtype
11871185 )
11881186 importer ._import_state_dict ()
0 commit comments