Skip to content

Commit ae89e12

Browse files
authored
fix: Reset parallelism configs to default after initial import (#1078)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
1 parent 191a160 commit ae89e12

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

nemo_rl/models/megatron/community_import.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ def import_model_from_hf_name(
3636

3737
model_provider = bridge.to_megatron_provider(load_weights=True)
3838

39+
# Keep track of defaults so we can restore them to the config after loading the model
40+
orig_tensor_model_parallel_size = model_provider.tensor_model_parallel_size
41+
orig_pipeline_model_parallel_size = model_provider.pipeline_model_parallel_size
42+
orig_expert_model_parallel_size = model_provider.expert_model_parallel_size
43+
orig_expert_tensor_parallel_size = model_provider.expert_tensor_parallel_size
44+
orig_num_layers_in_first_pipeline_stage = (
45+
model_provider.num_layers_in_first_pipeline_stage
46+
)
47+
orig_num_layers_in_last_pipeline_stage = (
48+
model_provider.num_layers_in_last_pipeline_stage
49+
)
50+
orig_pipeline_dtype = model_provider.pipeline_dtype
51+
3952
if megatron_config is not None:
4053
model_provider.tensor_model_parallel_size = megatron_config[
4154
"tensor_model_parallel_size"
@@ -59,6 +72,18 @@ def import_model_from_hf_name(
5972
model_provider.initialize_model_parallel(seed=0)
6073
megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)
6174

75+
# The above parallelism settings are used to load the model in a distributed manner.
76+
# However, we do not want to save the parallelism settings to the checkpoint config
77+
# because they may result in validation errors when loading the checkpoint.
78+
config = megatron_model[0].config
79+
config.tensor_model_parallel_size = orig_tensor_model_parallel_size
80+
config.pipeline_model_parallel_size = orig_pipeline_model_parallel_size
81+
config.expert_model_parallel_size = orig_expert_model_parallel_size
82+
config.expert_tensor_parallel_size = orig_expert_tensor_parallel_size
83+
config.num_layers_in_first_pipeline_stage = orig_num_layers_in_first_pipeline_stage
84+
config.num_layers_in_last_pipeline_stage = orig_num_layers_in_last_pipeline_stage
85+
config.pipeline_dtype = orig_pipeline_dtype
86+
6287
bridge.save_megatron_model(megatron_model, output_path)
6388

6489
# resetting mcore state

0 commit comments

Comments (0)