Commit 94b8e54

Merge pull request #755 from mlcommons/juhan/ddp_fix2
Allow DDP checkpointing
2 parents: 04ea6e1 + 80a93bf

File tree: 3 files changed (+29, -18 lines)

algorithmic_efficiency/checkpoint_utils.py

Lines changed: 6 additions & 2 deletions
@@ -119,7 +119,9 @@ def maybe_restore_checkpoint(framework: str,
 
   else:
     checkpoint_state = latest_ckpt
-    if isinstance(model_params, torch.nn.DataParallel):
+    if isinstance(
+        model_params,
+        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
       model_params = model_params.module
     model_params.load_state_dict(checkpoint_state['model_params'])
     checkpoint_state['model_params'] = model_params

@@ -196,7 +198,9 @@ def save_checkpoint(framework: str,
     opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
     model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
-    if isinstance(model_params, torch.nn.DataParallel):
+    if isinstance(
+        model_params,
+        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
       model_params = model_params.module
     model_params = model_params.state_dict()
     optimizer_state_dict = {}
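The change above makes checkpoint restore and save unwrap the model not only for torch.nn.DataParallel but also for torch.nn.parallel.DistributedDataParallel, so a DDP run can save and reload checkpoints. A minimal, self-contained sketch of the same unwrapping pattern is shown below; the helper names unwrap_model, save_model and load_model are illustrative and not part of the repository.

import torch


def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
  """Return the underlying model if it is wrapped for (distributed) data parallelism."""
  if isinstance(
      model,
      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
    # Both wrappers expose the original model as .module and prefix its
    # parameter names with 'module.' in state_dict().
    return model.module
  return model


def save_model(model: torch.nn.Module, path: str) -> None:
  # Saving the unwrapped state_dict keeps the keys free of the 'module.'
  # prefix, so the checkpoint loads with or without a parallel wrapper.
  torch.save(unwrap_model(model).state_dict(), path)


def load_model(model: torch.nn.Module, path: str) -> None:
  unwrap_model(model).load_state_dict(torch.load(path))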

algorithmic_efficiency/logger_utils.py

Lines changed: 15 additions & 10 deletions
@@ -16,6 +16,7 @@
 import GPUtil
 import pandas as pd
 import psutil
+import torch.distributed as dist
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.pytorch_utils import pytorch_setup

@@ -43,9 +44,6 @@ def get_log_dir(
     resume_last_run: bool,
     overwrite: bool,
 ) -> Optional[str]:
-  if RANK != 0:
-    return
-
   # Construct path to experiment workload directory.
   experiment_dir = os.path.expanduser(experiment_dir)
   workload_dir_name = f'{workload}_{framework}'

@@ -61,18 +59,25 @@
       logging.info(
           f'Removing existing experiment directory {experiment_path} because '
           '--overwrite was set.')
-      shutil.rmtree(experiment_path)
+      if RANK == 0:
+        shutil.rmtree(experiment_path)
     elif resume_last_run:
       logging.info(
           f'Resuming from experiment directory {experiment_path} because '
           '--resume_last_run was set.')
     else:
-      resume = input(
-          'Found existing experiment dir with the same name: {}. Do you wish '
-          'to resume training from this dir? [y/N]:'.format(experiment_path))
-      if resume.lower() != 'y':
-        sys.exit()
-
+      if RANK == 0:
+        resume = input(
+            'Found existing experiment dir with the same name: {}. Do you wish '
+            'to resume training from this dir? [y/N]:'.format(experiment_path))
+        if resume.lower() != 'y':
+          sys.exit()
+
+  if USE_PYTORCH_DDP:
+    try:
+      dist.barrier()
+    except RuntimeError:
+      sys.exit()
   logging.info(f'Creating experiment directory at {experiment_path}.')
   makedir(experiment_path)
   return experiment_path
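The logger_utils.py change replaces the old early return on non-zero ranks with targeted RANK == 0 guards around the filesystem operations and the interactive prompt, and then synchronizes all processes with dist.barrier() before the experiment directory is used. A hedged sketch of this rank-0-plus-barrier pattern follows, assuming the process group is initialized elsewhere; is_main_process and setup_experiment_dir are illustrative names, not functions from the repository.

import os
import sys

import torch.distributed as dist


def is_main_process() -> bool:
  # Treat a non-distributed run as the "main process" so the helper also
  # works without DDP.
  return not dist.is_initialized() or dist.get_rank() == 0


def setup_experiment_dir(experiment_path: str) -> str:
  # Only rank 0 touches the filesystem; the other ranks wait at the barrier
  # so nobody proceeds before the directory exists. If the barrier fails
  # (e.g. because rank 0 exited), the remaining ranks exit as well.
  if is_main_process():
    os.makedirs(experiment_path, exist_ok=True)
  if dist.is_initialized():
    try:
      dist.barrier()
    except RuntimeError:
      sys.exit()
  return experiment_path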

submission_runner.py

Lines changed: 8 additions & 6 deletions
@@ -316,10 +316,12 @@ def train_once(
     flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
     logging.info(f'Saving flags to {flag_file_name}.')
     logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
-    metrics_logger = logger_utils.set_up_loggers(log_dir,
-                                                 flags.FLAGS,
-                                                 hyperparameters)
-    workload.attach_metrics_logger(metrics_logger)
+    metrics_logger = None
+    if RANK == 0:
+      metrics_logger = logger_utils.set_up_loggers(log_dir,
+                                                   flags.FLAGS,
+                                                   hyperparameters)
+      workload.attach_metrics_logger(metrics_logger)
 
   global_start_time = get_time()
   train_state['last_step_end_time'] = global_start_time

@@ -429,7 +431,7 @@ def train_once(
 
     logging_start_time = get_time()
 
-    if log_dir is not None:
+    if log_dir is not None and RANK == 0:
       metrics_logger.append_scalar_metrics(
           latest_eval_result,
           global_step=global_step,

@@ -467,7 +469,7 @@ def train_once(
 
   metrics = {'eval_results': eval_results, 'global_step': global_step}
 
-  if log_dir is not None:
+  if log_dir is not None and RANK == 0:
     metrics_logger.append_scalar_metrics(
         {'score': train_state['accumulated_submission_time']},
         global_step=global_step,
