Commit 06e3ea9: revert training/checkpointing.py
Parent: b218786
1 file changed (+4, -4)

megatron/training/checkpointing.py

Lines changed: 4 additions & 4 deletions
@@ -1475,13 +1475,13 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
     ckpt_args = state_dict.get("args")

     if not hasattr(ckpt_args, "tensor_model_parallel_size"):
-        print_rank_0("WARNING: TP size not found in checkpoint args, using 1 as default.")
+        print_rank_0("WARNING: TP size not found in checkpoint args, using 0 as default.")
     if not hasattr(ckpt_args, "pipeline_model_parallel_size"):
-        print_rank_0("WARNING: PP size not found in checkpoint args, using 1 as default.")
+        print_rank_0("WARNING: PP size not found in checkpoint args, using 0 as default.")

     ckpt_tp_pp = (
-        getattr(ckpt_args, "tensor_model_parallel_size", 1),
-        getattr(ckpt_args, "pipeline_model_parallel_size", 1),
+        getattr(ckpt_args, "tensor_model_parallel_size", 0),
+        getattr(ckpt_args, "pipeline_model_parallel_size", 0),
     )
     run_tp_pp = (
         args.tensor_model_parallel_size,
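
For context, a minimal sketch of why the fallback value matters: when a checkpoint was saved without TP/PP metadata, the getattr default is the value that ends up in ckpt_tp_pp. The resolve_ckpt_tp_pp helper and the equality check against run_tp_pp below are illustrative assumptions based on the surrounding diff, not the actual Megatron-LM code.

# Minimal sketch (not the real Megatron-LM implementation) of the
# getattr-with-default pattern touched by this commit.
from types import SimpleNamespace

def resolve_ckpt_tp_pp(ckpt_args, default):
    # Mirrors the getattr calls in the diff above.
    return (
        getattr(ckpt_args, "tensor_model_parallel_size", default),
        getattr(ckpt_args, "pipeline_model_parallel_size", default),
    )

# A checkpoint whose saved args carry no parallelism metadata.
old_ckpt_args = SimpleNamespace()

# Hypothetical current-run settings: TP=1, PP=1.
run_tp_pp = (1, 1)

# With a default of 1, a metadata-less checkpoint is indistinguishable
# from one saved with TP=1/PP=1; with a default of 0 it is not.
print(resolve_ckpt_tp_pp(old_ckpt_args, 1) == run_tp_pp)  # True
print(resolve_ckpt_tp_pp(old_ckpt_args, 0) == run_tp_pp)  # False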
