@@ -109,7 +109,7 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor:
109109
110110def get_quantized_state (
111111 module : torch .nn .Module ,
112- dtype : torch .dtype = torch .float16 ,
112+ dtype : torch .dtype = torch .bfloat16 ,
113113) -> tuple [dict [str , torch .Tensor ], str , int ]:
114114 """Return a state_dict, quantization format, and block_size of the module.
115115
@@ -197,7 +197,12 @@ def __init__(
197197 self ._hf_config = transformers .AutoConfig .from_pretrained (
198198 pretrained_model_name_or_path , trust_remote_code = trust_remote_code
199199 )
200- self .moe_router_dtype = moe_router_dtype
200+ self .moe_router_dtype = None
201+ if moe_router_dtype == "fp32" :
202+ self .moe_router_dtype = torch .float32
203+ elif moe_router_dtype == "fp64" :
204+ self .moe_router_dtype = torch .float64
205+
201206 # If multimodal, extract the text_config
202207 self ._hf_text_config = getattr (self ._hf_config , "text_config" , self ._hf_config )
203208
@@ -1142,7 +1147,7 @@ def export_mcore_gpt_to_hf(
11421147 model : torch .nn .Module ,
11431148 pretrained_model_name_or_path : str | os .PathLike | None = None ,
11441149 export_extra_modules : bool = False ,
1145- dtype : torch .dtype = torch .float16 ,
1150+ dtype : torch .dtype = torch .bfloat16 ,
11461151 export_dir : Path | str = tempfile .gettempdir (),
11471152 moe_router_dtype : torch .dtype | None = None ,
11481153):
@@ -1169,7 +1174,7 @@ def import_mcore_gpt_from_hf(
11691174 model : torch .nn .Module ,
11701175 pretrained_model_path : str ,
11711176 workspace_dir : str | None = None ,
1172- dtype : torch .dtype = torch .float16 ,
1177+ dtype : torch .dtype = torch .bfloat16 ,
11731178 moe_router_dtype : torch .dtype | None = None ,
11741179):
11751180 """Import GPTModel state_dict from supported HuggingFace pretrained model path.
0 commit comments