main_distributed.py
import argparse
import os
from pprint import pprint

import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group, all_reduce
from torch.nn.parallel import DistributedDataParallel as DDP

from setup.cfg import Config
import src.dataset
import src.model
import src.optimizer
from src.loss import *
from src.logger import Logger
from src.trainer import Trainer
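
# ddp_setup joins each worker to a shared process group; MASTER_ADDR/MASTER_PORT define
# the rendezvous point, and the "nccl" backend expects one CUDA device per rank.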
def ddp_setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "1235"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
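
# torch.distributed.all_reduce defaults to ReduceOp.SUM, so after each call below every
# rank holds the sum of the per-rank counts before accuracy and mean_loss are derived.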
def aggregate_metrics(metrics_dict, rank):
    """Aggregate metrics across all GPUs by summing total_loss, n_correct, n_samples."""
    aggregated = {}
    for split_name, split_metrics in metrics_dict.items():
        # Convert metrics to tensors for all_reduce
        total_loss = torch.tensor(split_metrics['total_loss'], dtype=torch.float32, device=f'cuda:{rank}')
        n_correct = torch.tensor(split_metrics['n_correct'], dtype=torch.long, device=f'cuda:{rank}')
        n_samples = torch.tensor(split_metrics['n_samples'], dtype=torch.long, device=f'cuda:{rank}')
        # Sum across all GPUs
        all_reduce(total_loss)
        all_reduce(n_correct)
        all_reduce(n_samples)
        # Convert back and compute derived metrics
        aggregated[split_name] = {
            'total_loss': total_loss.item(),
            'n_correct': int(n_correct.item()),
            'n_samples': int(n_samples.item()),
            'accuracy': n_correct.item() / n_samples.item() if n_samples.item() > 0 else 0.0,
            'mean_loss': total_loss.item() / n_samples.item() if n_samples.item() > 0 else 0.0,
        }
    return aggregated
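
# Example with two ranks (hypothetical values) for a single "val" split:
#   rank 0 passes {"val": {"total_loss": 10.0, "n_correct": 40, "n_samples": 50}}
#   rank 1 passes {"val": {"total_loss": 12.0, "n_correct": 35, "n_samples": 50}}
# and every rank then receives
#   {"val": {"total_loss": 22.0, "n_correct": 75, "n_samples": 100,
#            "accuracy": 0.75, "mean_loss": 0.22}}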
def main_distributed(rank: int, cfg: Config):
    """
    Main script for distributed training. The rank (device) is supplied per process by
    torch.multiprocessing.
    """
    ddp_setup(rank=rank, world_size=cfg["n_gpus"])
    train_dl, val_dl, test_dl = src.dataset.init_dataloader(cfg["dataset"])
    model = src.model.init_model(cfg["run"]["model"]).to(rank)
    model = DDP(model, device_ids=[rank])
    loss = src.loss.init_loss(cfg["run"]["loss"])
    optimizer = src.optimizer.init_optimizer(cfg["run"]["optimizer"], model)

    # First initialize a trainer for each device.
    trainer = Trainer(train_dl, val_dl, test_dl, model, loss, optimizer, device=rank)

    # Then run an initial epoch to validate that everything runs.
    train_metrics = trainer.train(initial=True)
    val_metrics = trainer.val()
    metrics = {"train": train_metrics, "val": val_metrics}

    # Then initialize the logger, create the run directory, save the config,
    # and save the epoch 0 metrics.
    logger = Logger(cfg, rank)

    # Aggregate metrics across all GPUs.
    agg_metrics = aggregate_metrics(metrics, rank)
    logger.save_metrics(agg_metrics, epoch=cfg["run"]["epoch"])
    print(f"Epoch: {cfg['run']['epoch']}")
    pprint(agg_metrics)
    logger.save_state(trainer, epoch=cfg["run"]["epoch"])

    try:
        for epoch in range(cfg["run"]["epoch"] + 1, cfg["run"]["total_epochs"]):
            train_metrics = trainer.train()
            val_metrics = trainer.val()
            metrics = {"train": train_metrics, "val": val_metrics}
            # Aggregate metrics across all GPUs.
            agg_metrics = aggregate_metrics(metrics, rank)
            print(f"Epoch: {epoch}")
            pprint(agg_metrics)
            logger.save_metrics(agg_metrics, epoch)
            logger.save_state(trainer, epoch)

        # The final epoch also evaluates on the test split.
        train_metrics = trainer.train()
        val_metrics = trainer.val()
        test_metrics = trainer.test()
        metrics = {"train": train_metrics, "val": val_metrics, "test": test_metrics}
        # Aggregate metrics across all GPUs.
        agg_metrics = aggregate_metrics(metrics, rank)
        print(f"Epoch: {cfg['run']['total_epochs']}")
        pprint(agg_metrics)
        logger.save_metrics(agg_metrics, cfg["run"]["total_epochs"])
        logger.save_state(trainer, cfg["run"]["total_epochs"])
    except KeyboardInterrupt:
        print("Run failed. Keyboard interrupt.")
        print(f"Logs in: {cfg['log']['savedir']}")
    except Exception as e:
        print(e)
        print(f"Run failed. Logs in: {cfg['log']['savedir']}")
    finally:
        destroy_process_group()
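
# mp.spawn launches main_distributed(i, cfg) for i in 0..n_gpus-1, so the rank argument
# is supplied implicitly as the spawned process index.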
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, help='Path to config file.')
    parser.add_argument('--resume', action='store_true', help='Continue from prev run.')
    args = parser.parse_args()
    cfg = Config(args.cfg, resume=args.resume)

    import time
    now = time.time()
    if cfg["n_gpus"] > 1:
        mp.spawn(main_distributed, args=(cfg,), nprocs=cfg["n_gpus"])
    else:
        # Single-GPU run: execute the same entry point directly as rank 0.
        main_distributed(0, cfg)
    print(time.time() - now)