@@ -109,7 +109,7 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor:
109109
110110def get_quantized_state (
111111 module : torch .nn .Module ,
112- dtype : torch .dtype = torch .float16 ,
112+ dtype : torch .dtype = torch .bfloat16 ,
113113) -> tuple [dict [str , torch .Tensor ], str , int ]:
114114 """Return a state_dict, quantization format, and block_size of the module.
115115
@@ -186,6 +186,7 @@ def __init__(
186186 export_extra_modules : bool = False ,
187187 dtype = torch .bfloat16 ,
188188 trust_remote_code : bool = True ,
189+ moe_router_dtype : torch .dtype | None = None ,
189190 ):
190191 """Create a GPTModel exporter instance."""
191192 if not isinstance (model , (GPTModel , MambaModel , LLaVAModel )):
@@ -196,6 +197,12 @@ def __init__(
196197 self ._hf_config = transformers .AutoConfig .from_pretrained (
197198 pretrained_model_name_or_path , trust_remote_code = trust_remote_code
198199 )
200+ self .moe_router_dtype = None
201+ if moe_router_dtype == "fp32" :
202+ self .moe_router_dtype = torch .float32
203+ elif moe_router_dtype == "fp64" :
204+ self .moe_router_dtype = torch .float64
205+
199206 # If multimodal, extract the text_config
200207 self ._hf_text_config = getattr (self ._hf_config , "text_config" , self ._hf_config )
201208
@@ -489,7 +496,9 @@ def _custom_mapping_to_lambda(mapping):
489496 func = method_map [mapping .func_name ]
490497 prefix = mapping .target_name_or_prefix
491498 func_kwargs = mapping .func_kwargs
492- return lambda m , * args : func (m , prefix .format (* args ), ** func_kwargs )
499+ return lambda m , * args , ** kwargs : func (
500+ m , prefix .format (* args ), ** {** func_kwargs , ** kwargs }
501+ )
493502
494503 for arch , mappings in all_mcore_hf_export_mapping .items ():
495504 all_rules [arch ] = {
@@ -519,12 +528,16 @@ def _name_remapping(
519528 prefix : str ,
520529 skip_output_scale : bool = True ,
521530 mapping = {},
531+ dtype : torch .dtype | None = None ,
522532 ):
533+ if dtype is None :
534+ dtype = self .dtype
535+
523536 if isinstance (module , torch .Tensor ):
524537 self ._state_dict [prefix ] = module
525538 return
526539
527- name_to_value , qformat , block_size = get_quantized_state (module , self . dtype )
540+ name_to_value , qformat , block_size = get_quantized_state (module , dtype )
528541
529542 weight = name_to_value .pop ("weight" )
530543 weight_scale , weight_scale_2 = self ._get_weight_scales (name_to_value , qformat )
@@ -1098,7 +1111,9 @@ def _get_state_dict(self):
10981111
10991112 if not isinstance (layer .mlp , IdentityOp ):
11001113 if "MoE" in str (type (layer .mlp )):
1101- self .rules ["router" ](layer .mlp .router , layer_id )
1114+ self .rules ["router" ](
1115+ layer .mlp .router , layer_id , dtype = self .moe_router_dtype
1116+ )
11021117 if (
11031118 hasattr (layer .mlp , "shared_experts" )
11041119 and layer .mlp .shared_experts is not None
@@ -1136,8 +1151,9 @@ def export_mcore_gpt_to_hf(
11361151 model : torch .nn .Module ,
11371152 pretrained_model_name_or_path : str | os .PathLike | None = None ,
11381153 export_extra_modules : bool = False ,
1139- dtype : torch .dtype = torch .float16 ,
1154+ dtype : torch .dtype = torch .bfloat16 ,
11401155 export_dir : Path | str = tempfile .gettempdir (),
1156+ moe_router_dtype : torch .dtype | None = None ,
11411157):
11421158 """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.
11431159
@@ -1153,7 +1169,11 @@ def export_mcore_gpt_to_hf(
11531169 export_dir: The target export path.
11541170 """
11551171 exporter = GPTModelExporter (
1156- model , pretrained_model_name_or_path , export_extra_modules = export_extra_modules , dtype = dtype
1172+ model ,
1173+ pretrained_model_name_or_path ,
1174+ export_extra_modules = export_extra_modules ,
1175+ dtype = dtype ,
1176+ moe_router_dtype = moe_router_dtype ,
11571177 )
11581178 exporter .save_pretrained (export_dir , pretrained_model_name_or_path )
11591179
@@ -1162,7 +1182,8 @@ def import_mcore_gpt_from_hf(
11621182 model : torch .nn .Module ,
11631183 pretrained_model_path : str ,
11641184 workspace_dir : str | None = None ,
1165- dtype : torch .dtype = torch .float16 ,
1185+ dtype : torch .dtype = torch .bfloat16 ,
1186+ moe_router_dtype : torch .dtype | None = None ,
11661187):
11671188 """Import GPTModel state_dict from supported HuggingFace pretrained model path.
11681189
@@ -1173,6 +1194,10 @@ def import_mcore_gpt_from_hf(
11731194 dtype: The weights data type to import.
11741195 """
11751196 importer = GPTModelImporter (
1176- model , pretrained_model_path , workspace_dir = workspace_dir , dtype = dtype
1197+ model ,
1198+ pretrained_model_path ,
1199+ workspace_dir = workspace_dir ,
1200+ dtype = dtype ,
1201+ moe_router_dtype = moe_router_dtype ,
11771202 )
11781203 importer ._import_state_dict ()
0 commit comments