Skip to content

Commit a8e64f6

Browse files
committed
support checkpoints with distrib optimizer in checkpoint-util
1 parent bd12802 commit a8e64f6

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

megatron/checkpointing.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def ensure_directory_exists(filename):
9292

9393

9494
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
95-
pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
95+
pipeline_parallel=None, tensor_rank=None, pipeline_rank=None, only_model=False):
9696
"""Determine the directory name for this rank's checkpoint."""
9797
if release:
9898
directory = 'release'
@@ -119,8 +119,9 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
119119

120120
if use_distributed_optimizer:
121121
model_name = os.path.join(common_path, "model_rng.pt")
122+
data_parallel_rank = 0 if only_model else mpu.get_data_parallel_rank()
122123
optim_name = os.path.join(
123-
common_path + "_%03d" % mpu.get_data_parallel_rank(),
124+
common_path + "_%03d" % data_parallel_rank,
124125
"optim.pt")
125126
else:
126127
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
@@ -139,14 +140,14 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimize
139140
# Look for checkpoint with no pipelining
140141
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
141142
pipeline_parallel=False,
142-
tensor_rank=0, pipeline_rank=0)
143+
tensor_rank=0, pipeline_rank=0, only_model=True)
143144
if os.path.isfile(filenames[0]):
144145
return filenames
145146

146147
# Look for checkpoint with pipelining
147148
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
148149
pipeline_parallel=True,
149-
tensor_rank=0, pipeline_rank=0)
150+
tensor_rank=0, pipeline_rank=0, only_model=True)
150151
if os.path.isfile(filenames[0]):
151152
return filenames
152153

@@ -379,10 +380,11 @@ def fix_query_key_value_ordering(model, checkpoint_version):
379380
print_rank_0(" succesfully fixed query-key-values ordering for"
380381
" checkpoint version {}".format(checkpoint_version))
381382

382-
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None):
383+
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iteration=None, release=None, no_load_optim=False):
383384
""" Load the base state_dict from the given directory
384385
385386
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
387+
If rank0 is true or no_load_optim is true, we do not care about the optimizer, only the model checkpoint.
386388
"""
387389

388390
# Read the tracker file and set the iteration.
@@ -408,7 +410,7 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, iter
408410
release)
409411
else:
410412
checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
411-
release)
413+
release, only_model=no_load_optim)
412414
if release:
413415
print_rank_0(f' loading release checkpoint from {load_dir}')
414416
else:
@@ -572,7 +574,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
572574
use_distributed_optimizer=args.use_distributed_optimizer,
573575
rank0=False,
574576
iteration=iteration,
575-
release=release)
577+
release=release,
578+
no_load_optim=args.no_load_optim)
576579

577580
if model_state_dict is None:
578581
return 0

tools/checkpoint_loader_megatron.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def _load_checkpoint(queue, args):
5050
'--no-initialization',
5151
'--load', args.load_dir
5252
]
53+
if args.use_distributed_optimizer:
54+
sys.argv.append("--use-distributed-optimizer")
55+
5356

5457
margs = parse_args()
5558
margs = load_args_from_checkpoint(margs)

tools/checkpoint_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def main():
124124
parser.add_argument('--no-checking', action='store_false',
125125
help='Do not perform checking on the name and ordering of weights',
126126
dest='checking')
127+
128+
parser.add_argument('--use-distributed-optimizer', action='store_true',
129+
help='Loaded checkpoint uses distributed optimizer.')
127130

128131
known_args, _ = parser.parse_known_args()
129132
loader = load_plugin('loader', known_args.loader)

0 commit comments

Comments (0)