@@ -92,7 +92,7 @@ def ensure_directory_exists(filename):
 
 
 def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
-                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None, only_model=False):
     """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
@@ -119,8 +119,9 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
 
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
+        data_parallel_rank = 0 if only_model else mpu.get_data_parallel_rank()
         optim_name = os.path.join(
-            common_path + "_%03d" % mpu.get_data_parallel_rank(),
+            common_path + "_%03d" % data_parallel_rank,
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -139,14 +140,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimize
     # Look for checkpoint with no pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=False,
-                                     tensor_rank=0, pipeline_rank=0)
+                                     tensor_rank=0, pipeline_rank=0, only_model=True)
     if os.path.isfile(filenames[0]):
         return filenames
 
     # Look for checkpoint with pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=True,
-                                     tensor_rank=0, pipeline_rank=0)
+                                     tensor_rank=0, pipeline_rank=0, only_model=True)
     if os.path.isfile(filenames[0]):
         return filenames
 
@@ -379,10 +380,11 @@ def fix_query_key_value_ordering(model, checkpoint_version):
     print_rank_0(" succesfully fixed query-key-values ordering for"
                  " checkpoint version {}".format(checkpoint_version))
 
-def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None):
+def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None, no_load_optim=False):
     """ Load the base state_dict from the given directory
 
     If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
+    If rank0 is true or no_load_optim is true, we do not care about the optimizer, only the model checkpoint.
     """
 
     # Read the tracker file and set the iteration.
@@ -408,7 +410,7 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
                                                   release)
     else:
         checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
-                                                release)
+                                                release, only_model=no_load_optim)
     if release:
         print_rank_0(f' loading release checkpoint from {load_dir}')
     else:
@@ -572,7 +574,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         use_distributed_optimizer=args.use_distributed_optimizer,
         rank0=False,
         iteration=iteration,
-        release=release)
+        release=release,
+        no_load_optim=args.no_load_optim)
 
     if model_state_dict is None:
         return 0
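
For context on what the new only_model flag buys, here is a minimal, self-contained sketch. It is not the Megatron-LM code itself: sketch_checkpoint_names is a hypothetical stand-in that omits the tensor/pipeline mp_rank subdirectories, and the iteration-directory format is assumed. It only illustrates how pinning the data-parallel rank to 0 lets a caller build a model-only checkpoint path without an initialized data-parallel group.

import os

def sketch_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
                            data_parallel_rank, only_model=False):
    # Simplified mirror of the naming logic in this diff: a per-iteration directory,
    # with separate model and optimizer files when the distributed optimizer is used.
    common_path = os.path.join(checkpoints_path, 'iter_{:07d}'.format(iteration))
    if use_distributed_optimizer:
        model_name = os.path.join(common_path, "model_rng.pt")
        # only_model=True pins the data-parallel rank to 0, so building the path
        # does not require querying a (possibly uninitialized) data-parallel group.
        dp_rank = 0 if only_model else data_parallel_rank
        optim_name = os.path.join(common_path + "_%03d" % dp_rank, "optim.pt")
    else:
        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
    return model_name, optim_name

# With only_model=True the optimizer path always resolves to the rank-0 shard,
# e.g. /ckpts/iter_0001000_000/optim.pt, regardless of the caller's own rank.
print(sketch_checkpoint_names("/ckpts", 1000, True, data_parallel_rank=3, only_model=True))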