run.py
# Entry point for the tagging experiments: picks a world size, spawns one
# DDP worker per GPU if needed, and dispatches to the requested experiment.
import datetime
import os

import hydra
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from experiments.tagging.experiment import TopTaggingExperiment
from experiments.tagging.finetuneexperiment import TopTaggingFineTuneExperiment
from experiments.tagging.jetclassexperiment import JetClassTaggingExperiment
from experiments.tagging.toptagxlexperiment import TopTagXLExperiment


@hydra.main(config_path="config_quick", config_name="toptagging", version_base=None)
def main(cfg):
    # Decide how many processes to launch: all visible GPUs (gpus == -1),
    # an explicit GPU count, or a single process on CPU.
    if torch.cuda.is_available() and cfg.gpus == -1:
        world_size = torch.cuda.device_count()
    elif torch.cuda.is_available() and cfg.gpus >= 1:
        world_size = cfg.gpus
    else:
        world_size = 1

    if world_size > 1:
        # Conservative NCCL/threading defaults for multi-GPU runs;
        # setdefault keeps any values already exported by the user.
        os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
        os.environ.setdefault("NCCL_DEBUG", "WARN")
        os.environ.setdefault("NCCL_IB_DISABLE", "1")
        os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
        os.environ.setdefault("OMP_NUM_THREADS", "1")
        _set_common_env(world_size)
        # mp.spawn passes the process rank as the first argument,
        # so each worker is called as ddp_worker(rank, cfg).
        mp.spawn(ddp_worker, nprocs=world_size, args=(cfg,))
    else:
        # no GPU or only one GPU -> run on the main process
        ddp_worker(rank=0, cfg=cfg)


def ddp_worker(rank, cfg):
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    if world_size > 1:
        # set up communication between processes
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
            timeout=datetime.timedelta(minutes=30),
        )
        torch.cuda.set_device(rank)

    # Map the requested experiment type to its class.
    if cfg.exp_type == "toptagging":
        constructor = TopTaggingExperiment
    elif cfg.exp_type == "toptaggingft":
        constructor = TopTaggingFineTuneExperiment
    elif cfg.exp_type == "toptagxl":
        constructor = TopTagXLExperiment
    elif cfg.exp_type == "jctagging":
        constructor = JetClassTaggingExperiment
    else:
        raise ValueError(f"exp_type {cfg.exp_type} not implemented")

    exp = constructor(cfg, rank, world_size)
    exp()

    if world_size > 1:
        # Sync all ranks before tearing down the process group.
        dist.barrier(device_ids=[rank])
        dist.destroy_process_group()


def _set_common_env(world_size):
    # Rendezvous settings for init_method="env://"; single node, local master.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
    os.environ["WORLD_SIZE"] = str(world_size)


def _find_free_port():
    import socket

    # Bind to port 0 and let the OS pick an unused port for the rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


if __name__ == "__main__":
    main()
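

# Usage sketch (an assumption, not part of the original file): per the
# @hydra.main decorator above, Hydra loads config_quick/toptagging.yaml,
# and since the code reads cfg.gpus and cfg.exp_type, those fields can be
# overridden on the command line with Hydra's key=value syntax, e.g.:
#
#   python run.py gpus=-1 exp_type=toptagging   # one process per visible GPU
#   python run.py gpus=1 exp_type=jctagging     # single process (GPU or CPU)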