
Commit 476911d

Pid port + duplicate rank_zero logging (#2231)
* init the port using a seed that matches process id for ddp

Co-authored-by: Zhaofeng Wu <[email protected]>
1 parent 15cf6a8 commit 476911d

File tree

2 files changed: +16 −5 lines changed


pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 6 additions & 4 deletions

@@ -372,14 +372,16 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
     def __set_random_port(self):
         """
         When running DDP NOT managed by SLURM, the ports might collide
-        :return:
         """
         try:
             default_port = os.environ['MASTER_PORT']
         except Exception:
-            import random
-            default_port = random.randint(10000, 19000)
-            os.environ['MASTER_PORT'] = str(default_port)
+            # use the process id as a seed to a generator for port only
+            pid = os.getpid()
+            rng1 = np.random.RandomState(pid)
+            default_port = rng1.randint(10000, 19999, 1)[0]
+
+            os.environ['MASTER_PORT'] = str(default_port)

     def spawn_ddp_children(self, model):
         self.__set_random_port()
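For readers skimming the diff, the sketch below restates the new port-selection logic as a standalone, runnable snippet. It assumes only os and numpy; the helper name pick_ddp_port is hypothetical and not part of the Lightning API.

import os

import numpy as np


def pick_ddp_port() -> int:
    """Derive MASTER_PORT the way the hunk above does: a dedicated NumPy
    generator seeded with the process id, so the choice neither depends on
    nor disturbs the global random state."""
    if 'MASTER_PORT' in os.environ:
        return int(os.environ['MASTER_PORT'])
    rng = np.random.RandomState(os.getpid())
    # randint(low, high, size) samples from [low, high); take the single draw
    port = int(rng.randint(10000, 19999, 1)[0])
    os.environ['MASTER_PORT'] = str(port)
    return port


if __name__ == '__main__':
    # within one process the derived port is stable across calls
    print(pick_ddp_port(), pick_ddp_port())

Within one process repeated calls return the same port, while trainers launched separately on the same node seed with different PIDs and therefore almost always land on different ports.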

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 1 deletion

@@ -32,7 +32,7 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info
+from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only

 try:
     from apex import amp
@@ -322,6 +322,14 @@ def __init__(
         # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
         os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

+        # init the default rank if exists
+        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
+        # this way we only show it on rank 0
+        if 'LOCAL_RANK' in os.environ:
+            rank_zero_only.rank = os.environ['LOCAL_RANK']
+        if 'SLURM_JOB_ID' in os.environ:
+            rank_zero_only.rank = os.environ['SLURM_JOB_ID']
+
         # Init callbacks
         self.prepare_data_per_node = prepare_data_per_node
         self.callbacks = callbacks or []
@@ -892,6 +900,7 @@ def fit(
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

         elif self.distributed_backend == 'ddp_spawn':
+            self.__set_random_port()
             model.share_memory()

             # spin up peers
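The LOCAL_RANK / SLURM_JOB_ID assignments work because rank_zero_only stores the current rank as an attribute on the decorator itself, which the rank-zero logging helpers consult before printing. The sketch below is a simplified stand-in under that assumption, not Lightning's actual implementation; the getattr default and the int() coercion are choices made for the example.

import os
from functools import wraps


def rank_zero_only(fn):
    """Run the wrapped function only on global rank 0; the current rank is
    stored as an attribute on the decorator itself, mirroring the attribute
    the Trainer sets in the hunk above."""
    @wraps(fn)
    def wrapped(*args, **kwargs):
        if getattr(rank_zero_only, 'rank', 0) == 0:
            return fn(*args, **kwargs)
    return wrapped


# mirror the change above: pick up the rank before any init-time logging,
# so messages emitted during Trainer.__init__ appear once instead of N times
rank_zero_only.rank = int(os.environ.get('LOCAL_RANK', 0))


@rank_zero_only
def rank_zero_info(msg: str) -> None:
    print(msg)


if __name__ == '__main__':
    rank_zero_info('printed only when LOCAL_RANK is unset or 0')

Setting rank_zero_only.rank this early matters because the NVIDIA-flag and other messages are emitted inside Trainer.__init__, before the usual distributed setup assigns ranks; without it, every process would print them.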
