@@ -36,6 +36,19 @@ def import_model_from_hf_name(
 
     model_provider = bridge.to_megatron_provider(load_weights=True)
 
+    # Keep track of defaults so we can restore them to the config after loading the model
+    orig_tensor_model_parallel_size = model_provider.tensor_model_parallel_size
+    orig_pipeline_model_parallel_size = model_provider.pipeline_model_parallel_size
+    orig_expert_model_parallel_size = model_provider.expert_model_parallel_size
+    orig_expert_tensor_parallel_size = model_provider.expert_tensor_parallel_size
+    orig_num_layers_in_first_pipeline_stage = (
+        model_provider.num_layers_in_first_pipeline_stage
+    )
+    orig_num_layers_in_last_pipeline_stage = (
+        model_provider.num_layers_in_last_pipeline_stage
+    )
+    orig_pipeline_dtype = model_provider.pipeline_dtype
+
     if megatron_config is not None:
         model_provider.tensor_model_parallel_size = megatron_config[
             "tensor_model_parallel_size"
@@ -59,6 +72,18 @@ def import_model_from_hf_name(
     model_provider.initialize_model_parallel(seed=0)
     megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)
 
+    # The above parallelism settings are used to load the model in a distributed manner.
+    # However, we do not want to save the parallelism settings to the checkpoint config
+    # because they may result in validation errors when loading the checkpoint.
+    config = megatron_model[0].config
+    config.tensor_model_parallel_size = orig_tensor_model_parallel_size
+    config.pipeline_model_parallel_size = orig_pipeline_model_parallel_size
+    config.expert_model_parallel_size = orig_expert_model_parallel_size
+    config.expert_tensor_parallel_size = orig_expert_tensor_parallel_size
+    config.num_layers_in_first_pipeline_stage = orig_num_layers_in_first_pipeline_stage
+    config.num_layers_in_last_pipeline_stage = orig_num_layers_in_last_pipeline_stage
+    config.pipeline_dtype = orig_pipeline_dtype
+
     bridge.save_megatron_model(megatron_model, output_path)
 
     # resetting mcore state
0 commit comments