@@ -180,6 +180,7 @@ def __init__(self,
180180 if self .logger is None :
181181 self .logger = TestTubeLogger (
182182 save_dir = self .default_save_path ,
183+ version = self .slurm_job_id ,
183184 name = 'lightning_logs'
184185 )
185186
@@ -240,6 +241,15 @@ def __init__(self,
240241 self .amp_level = amp_level
241242 self .__init_amp (use_amp )
242243
@property
def slurm_job_id(self):
    """
    Job id read from the ``SLURM_JOB_ID`` environment variable.

    :return: the value of ``SLURM_JOB_ID`` converted to an int, or None when
        the variable is not set or cannot be parsed as an integer
    """
    try:
        return int(os.environ['SLURM_JOB_ID'])
    except (KeyError, ValueError):
        # KeyError: variable is not set; ValueError: value is not an integer.
        # Either way there is no usable job id, so report None.
        return None
252+
243253 def __configure_weights_path (self , checkpoint_callback , weights_save_path ):
244254 """
245255 Weight path set in this priority:
@@ -882,12 +892,25 @@ def __init_tcp_connection(self):
882892 :param tries:
883893 :return:
884894 """
885- # sets the appropriate port
895+
896+ # use slurm job id for the port number
897+ # guarantees unique ports across jobs from same grid search
898+ try :
899+ # use the last 4 numbers in the job id as the id
900+ default_port = os .environ ['SLURM_JOB_ID' ]
901+ default_port = default_port [- 4 :]
902+
903+ # all ports should be in the 10k+ range
904+ default_port = int (default_port ) + 10000
905+
906+ except Exception as e :
907+ default_port = 12910
908+
909+ # if user gave a port number, use that one instead
886910 try :
887911 port = os .environ ['MASTER_PORT' ]
888912 except Exception :
889- port = 12910
890- os .environ ['MASTER_PORT' ] = str (port )
913+ os .environ ['MASTER_PORT' ] = str (default_port )
891914
892915 # figure out the root node addr
893916 try :
0 commit comments