 
 from modelopt import __version__
 from modelopt.torch.utils import import_plugin
+from megatron.core import ModelParallelConfig
 
 from .model_config import (
     KV_CACHE_FP8,
@@ -186,6 +187,7 @@ def __init__(
         export_extra_modules: bool = False,
         dtype=torch.bfloat16,
         trust_remote_code: bool = True,
+        config: ModelParallelConfig | None = None,
     ):
         """Create a GPTModel exporter instance."""
         if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
@@ -196,6 +198,9 @@ def __init__(
         self._hf_config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, trust_remote_code=trust_remote_code
         )
+        self.moe_router_dtype = None
+        if config is not None and config.moe_router_dtype == "fp32":
+            self.moe_router_dtype = torch.float32
         # If multimodal, extra the text_config
         self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config)
 
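For reference, a minimal standalone sketch (not part of the patch) of how the guard above turns Megatron's string-valued `moe_router_dtype` option into a `torch.dtype`. The helper name and the `fp64` branch are illustrative assumptions, not something this diff adds.

```python
import torch


def _resolve_router_dtype(config) -> torch.dtype | None:
    """Map a Megatron config's moe_router_dtype string to a torch dtype (sketch)."""
    if config is None:
        return None  # exporter falls back to its default dtype
    router_dtype = getattr(config, "moe_router_dtype", None)
    # "fp32" is what the patch handles; "fp64" is an assumed extension.
    return {"fp32": torch.float32, "fp64": torch.float64}.get(router_dtype)
```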
@@ -486,9 +491,11 @@ def _custom_mapping_to_lambda(mapping):
                 "pack_name_remapping": self._pack_name_remapping,
                 "pack_name_remapping_gpt_oss": self._pack_name_remapping_gpt_oss,
             }
             func = method_map[mapping.func_name]
             prefix = mapping.target_name_or_prefix
             func_kwargs = mapping.func_kwargs
-            return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
+            # Forward per-call keyword overrides (e.g. dtype=...) to the underlying
+            # mapping function on top of the static func_kwargs from the rule book.
+            return lambda m, *args, **kwargs: func(m, prefix.format(*args), **func_kwargs, **kwargs)
 
         for arch, mappings in all_mcore_hf_export_mapping.items():
@@ -519,12 +526,16 @@ def _name_remapping(
         prefix: str,
         skip_output_scale: bool = True,
         mapping={},
+        dtype: torch.dtype | None = None,
     ):
+        if dtype is None:
+            dtype = self.dtype
+
         if isinstance(module, torch.Tensor):
             self._state_dict[prefix] = module
             return
 
-        name_to_value, qformat, block_size = get_quantized_state(module, self.dtype)
+        name_to_value, qformat, block_size = get_quantized_state(module, dtype)
 
         weight = name_to_value.pop("weight")
         weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
@@ -1098,7 +1109,7 @@ def _get_state_dict(self):
 
             if not isinstance(layer.mlp, IdentityOp):
                 if "MoE" in str(type(layer.mlp)):
-                    self.rules["router"](layer.mlp.router, layer_id)
+                    self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype)
                 if (
                     hasattr(layer.mlp, "shared_experts")
                     and layer.mlp.shared_experts is not None
@@ -1138,6 +1149,7 @@ def export_mcore_gpt_to_hf(
     export_extra_modules: bool = False,
     dtype: torch.dtype = torch.float16,
     export_dir: Path | str = tempfile.gettempdir(),
+    config: ModelParallelConfig | None = None,
 ):
     """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.
 
@@ -1153,7 +1165,7 @@ def export_mcore_gpt_to_hf(
         export_dir: The target export path.
     """
     exporter = GPTModelExporter(
-        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype
+        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype, config=config
     )
     exporter.save_pretrained(export_dir, pretrained_model_name_or_path)
 
@@ -1173,6 +1185,6 @@ def import_mcore_gpt_from_hf(
         dtype: The weights data type to import.
     """
     importer = GPTModelImporter(
-        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype
+        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype,
     )
-    importer._import_state_dict()
+    importer._import_state_dict()
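Overall, the patch threads the Megatron parallel config through to the exporter so MoE router weights can be kept in fp32. A minimal usage sketch under stated assumptions: the import path, checkpoint name, and export directory are placeholders, and `model` stands for an already-built (possibly quantized) Megatron Core GPTModel whose `config` carries `moe_router_dtype="fp32"`.

```python
import torch

# Assumed import path; the public entry point may be re-exported elsewhere in modelopt.
from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf

export_mcore_gpt_to_hf(
    model,                        # existing Megatron Core GPTModel (placeholder)
    "org/some-moe-model",         # placeholder HF repo providing the target config
    dtype=torch.bfloat16,
    export_dir="/tmp/hf_export",  # placeholder export directory
    config=model.config,          # new argument from this patch; carries moe_router_dtype
)
```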