Skip to content

Commit cdfcb01

Browse files
Fixes #234 (#311)
* Fixes #234 * default logger version is now slurm job id * default logger version is now slurm job id
1 parent ed86bf9 commit cdfcb01

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def __init__(self,
180180
if self.logger is None:
181181
self.logger = TestTubeLogger(
182182
save_dir=self.default_save_path,
183+
version=self.slurm_job_id,
183184
name='lightning_logs'
184185
)
185186

@@ -240,6 +241,15 @@ def __init__(self,
240241
self.amp_level = amp_level
241242
self.__init_amp(use_amp)
242243

244+
@property
def slurm_job_id(self):
    """Return the SLURM job id as an ``int``, or ``None`` if unavailable.

    The value is read from the ``SLURM_JOB_ID`` environment variable.

    :return: the integer job id, or ``None`` when the variable is unset
        or its value cannot be parsed as an integer.
    """
    raw_job_id = os.environ.get('SLURM_JOB_ID')
    if raw_job_id is None:
        return None

    try:
        return int(raw_job_id)
    except ValueError:
        # malformed value: fall back to None, same as when the variable is unset
        return None
252+
243253
def __configure_weights_path(self, checkpoint_callback, weights_save_path):
244254
"""
245255
Weight path set in this priority:
@@ -882,12 +892,25 @@ def __init_tcp_connection(self):
882892
:param tries:
883893
:return:
884894
"""
885-
# sets the appropriate port
895+
896+
# use slurm job id for the port number
897+
# keeps ports distinct across jobs from the same grid search, provided the last 4 digits of their job ids differ
898+
try:
899+
# use the last 4 digits of the job id as the basis for the port number
900+
default_port = os.environ['SLURM_JOB_ID']
901+
default_port = default_port[-4:]
902+
903+
# all ports should be in the 10k+ range
904+
default_port = int(default_port) + 10000
905+
906+
except Exception as e:
907+
default_port = 12910
908+
909+
# if user gave a port number, use that one instead
886910
try:
887911
port = os.environ['MASTER_PORT']
888912
except Exception:
889-
port = 12910
890-
os.environ['MASTER_PORT'] = str(port)
913+
os.environ['MASTER_PORT'] = str(default_port)
891914

892915
# figure out the root node addr
893916
try:

0 commit comments

Comments
 (0)