@@ -92,7 +92,7 @@ def ensure_directory_exists(filename):
 
 
 def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
-                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
+                         pipeline_parallel=None, tensor_rank=None, pipeline_rank=None, only_model=False):
     """Determine the directory name for this rank's checkpoint."""
     if release:
         directory = 'release'
@@ -119,7 +119,7 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
 
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
-        optim_name = os.path.join(
+        optim_name = None if only_model else os.path.join(
             common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
@@ -139,14 +139,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimize
     # Look for checkpoint with no pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=False,
-                                     tensor_rank=0, pipeline_rank=0)
+                                     tensor_rank=0, pipeline_rank=0, only_model=True)
     if os.path.isfile(filenames[0]):
         return filenames
 
     # Look for checkpoint with pipelining
     filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
                                      pipeline_parallel=True,
-                                     tensor_rank=0, pipeline_rank=0)
+                                     tensor_rank=0, pipeline_rank=0, only_model=True)
     if os.path.isfile(filenames[0]):
         return filenames
 
@@ -379,10 +379,11 @@ def fix_query_key_value_ordering(model, checkpoint_version):
     print_rank_0(" succesfully fixed query-key-values ordering for"
                  " checkpoint version {}".format(checkpoint_version))
 
-def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None):
+def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None, no_load_optim=False):
     """ Load the base state_dict from the given directory
 
     If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
+    If rank0 is true or no_load_optim is true, only the model checkpoint is loaded; the optimizer state is skipped.
     """
 
     # Read the tracker file and set the iteration.
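
The docstring addition above states the rule the rest of the patch implements: whenever the call is a rank-0 probe or `--no-load-optim` is in effect, only the model file matters. A minimal sketch of that rule, assuming nothing beyond the two flags (the helper name `should_load_optim` is hypothetical and not part of the patch):

```python
def should_load_optim(rank0: bool, no_load_optim: bool) -> bool:
    # The optimizer checkpoint is read only when neither the rank-0 probe
    # nor no_load_optim is in effect; in every other case only the model
    # checkpoint is needed.
    return not (rank0 or no_load_optim)
```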
@@ -408,7 +409,7 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
                                                   release)
     else:
         checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
-                                                release)
+                                                release, only_model=no_load_optim)
     if release:
         print_rank_0(f' loading release checkpoint from {load_dir}')
     else:
@@ -419,7 +420,9 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
     # Load the checkpoint.
     try:
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
-        if use_distributed_optimizer:
+        if rank0 or no_load_optim:
+            optim_state_dict = None
+        elif use_distributed_optimizer:
             optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         else:
             optim_state_dict = model_state_dict
@@ -572,7 +575,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         use_distributed_optimizer=args.use_distributed_optimizer,
         rank0=False,
         iteration=iteration,
-        release=release)
+        release=release,
+        no_load_optim=args.no_load_optim)
 
     if model_state_dict is None:
         return 0
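
With `--no-load-optim` set, the tuple that comes back from `_load_base_checkpoint` now carries `None` in place of the optimizer state, so everything downstream of this call must tolerate a missing optimizer state dict. A self-contained sketch of that guard, assuming an `optim_state_dict['optimizer']` layout; the helper name `load_optimizer_state` is illustrative and not part of this patch:

```python
from typing import Any, Dict, Optional

import torch


def load_optimizer_state(optimizer: torch.optim.Optimizer,
                         optim_state_dict: Optional[Dict[str, Any]]) -> bool:
    """Restore optimizer state only if a state dict was actually loaded.

    Returns False when the checkpoint was read with no_load_optim=True
    (or via the rank-0 probe), in which case optim_state_dict is None
    and the optimizer keeps its freshly initialized state.
    """
    if optim_state_dict is None:
        return False
    optimizer.load_state_dict(optim_state_dict['optimizer'])  # key layout assumed
    return True
```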