
Commit 57f21b7

don't load optimizer instead of arbitrarily loading dp-rank 0
1 parent a8e64f6 commit 57f21b7

File tree

1 file changed (5 additions, 4 deletions)

megatron/checkpointing.py

Lines changed: 5 additions & 4 deletions

@@ -119,9 +119,8 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
 
     if use_distributed_optimizer:
         model_name = os.path.join(common_path, "model_rng.pt")
-        data_parallel_rank = 0 if only_model else mpu.get_data_parallel_rank()
-        optim_name = os.path.join(
-            common_path + "_%03d" % data_parallel_rank,
+        optim_name = None if only_model else os.path.join(
+            common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
         model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
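For illustration, a minimal standalone sketch of the naming logic after this hunk (the function name, the `data_parallel_rank` argument standing in for Megatron's `mpu.get_data_parallel_rank()`, and the example paths are all hypothetical): when only the model is wanted, no optimizer path is produced at all instead of falling back to dp-rank 0's shard.

import os

def checkpoint_names_sketch(common_path, use_distributed_optimizer,
                            only_model, data_parallel_rank):
    # Mirrors the patched branch of get_checkpoint_names(): with a distributed
    # optimizer, asking for the model only yields optim_name=None rather than
    # arbitrarily pointing at data-parallel rank 0's optimizer shard.
    if use_distributed_optimizer:
        model_name = os.path.join(common_path, "model_rng.pt")
        optim_name = None if only_model else os.path.join(
            common_path + "_%03d" % data_parallel_rank, "optim.pt")
    else:
        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
    return model_name, optim_name

# Hypothetical usage (POSIX paths):
print(checkpoint_names_sketch("ckpt/iter_0001000", True, True, 3))
# -> ('ckpt/iter_0001000/model_rng.pt', None)
print(checkpoint_names_sketch("ckpt/iter_0001000", True, False, 3))
# -> ('ckpt/iter_0001000/model_rng.pt', 'ckpt/iter_0001000_003/optim.pt')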

@@ -421,7 +420,9 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
     # Load the checkpoint.
     try:
         model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
-        if use_distributed_optimizer:
+        if rank0 or no_load_optim:
+            optim_state_dict = None
+        elif use_distributed_optimizer:
             optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
         else:
             optim_state_dict = model_state_dict
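And a minimal, self-contained sketch of the load-side change (argument names follow the diff; the helper itself is hypothetical and elides Megatron's surrounding retry and error handling): when `rank0` is set or `no_load_optim` is requested, no optimizer file is read at all.

import torch

def load_states_sketch(model_checkpoint_name, optim_checkpoint_name,
                       use_distributed_optimizer, rank0=False, no_load_optim=False):
    # Always load the model state.
    model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
    if rank0 or no_load_optim:
        # Skip the optimizer entirely instead of reading some other rank's shard.
        optim_state_dict = None
    elif use_distributed_optimizer:
        # Distributed optimizer: each data-parallel rank has its own optim file.
        optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
    else:
        # Non-distributed optimizer: model and optimizer live in one file.
        optim_state_dict = model_state_dict
    return model_state_dict, optim_state_dict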
