@@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus):
 class Trainer(TrainerIO):
 
     def __init__(self,
-                 experiment,
+                 experiment=None,
                  early_stop_callback=None,
                  checkpoint_callback=None,
                  gradient_clip=0,
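
The signature change above is the core of the diff: `experiment` becomes optional, so a `Trainer` can be constructed with no logging backend at all. A minimal usage sketch, assuming the package's public `Trainer` import (the import path is not shown in this diff) and defaults for the remaining constructor arguments:

    from pytorch_lightning import Trainer

    # No experiment: the trainer still runs, and every experiment-dependent
    # step (save path, logging, progress-bar version number) is skipped,
    # as the guards in the hunks below show.
    trainer = Trainer()

    # The previous, still-supported usage passes a test-tube Experiment:
    # trainer = Trainer(experiment=my_experiment)
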
@@ -122,7 +122,9 @@ def __init__(self,
         self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.progress_bar = progress_bar
         self.experiment = experiment
-        self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
+        self.exp_save_path = None
+        if self.experiment is not None:
+            self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
         self.cluster = cluster
         self.process_position = process_position
         self.current_gpu_name = current_gpu_name
@@ -312,13 +314,15 @@ def __is_function_implemented(self, f_name):
 
     @property
     def __tng_tqdm_dic(self):
-        # ForkedPdb().set_trace()
         tqdm_dic = {
             'tng_loss': '{0:.3f}'.format(self.avg_loss),
-            'v_nb': '{}'.format(self.experiment.version),
             'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
+
+        if self.experiment is not None:
+            tqdm_dic['v_nb'] = self.experiment.version
+
         tqdm_dic.update(self.tqdm_metrics)
 
         if self.on_gpu:
@@ -462,7 +466,8 @@ def fit(self, model):
         if self.use_ddp:
             # must copy only the meta of the exp so it survives pickle/unpickle
             # when going to new process
-            self.experiment = self.experiment.get_meta_copy()
+            if self.experiment is not None:
+                self.experiment = self.experiment.get_meta_copy()
 
         if self.is_slurm_managing_tasks:
             task = int(os.environ['SLURM_LOCALID'])
@@ -564,8 +569,9 @@ def ddp_train(self, gpu_nb, model):
 
         # recover original exp before went into process
         # init in write mode only on proc 0
-        self.experiment.debug = self.proc_rank > 0
-        self.experiment = self.experiment.get_non_ddp_exp()
+        if self.experiment is not None:
+            self.experiment.debug = self.proc_rank > 0
+            self.experiment = self.experiment.get_non_ddp_exp()
 
         # show progbar only on prog_rank 0
         self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
@@ -575,7 +581,8 @@ def ddp_train(self, gpu_nb, model):
         self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
 
         # let the exp know the rank to avoid overwriting logs
-        self.experiment.rank = self.proc_rank
+        if self.experiment is not None:
+            self.experiment.rank = self.proc_rank
 
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
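
Taken together, the `fit` and `ddp_train` hunks give the experiment a pickle-safe round trip: the parent process reduces it to a meta copy before spawning workers, and each worker restores a writable experiment on rank 0 only. A runnable toy of that flow; `DummyExperiment` mimics just the attributes the diff touches and is not the real test-tube API:

    class DummyExperiment:
        """Stand-in exposing only the surface used by the diff (not test-tube)."""
        def __init__(self, version=0):
            self.version, self.debug, self.rank = version, False, 0

        def get_meta_copy(self):
            # parent side: keep only picklable metadata
            return DummyExperiment(self.version)

        def get_non_ddp_exp(self):
            # worker side: rebuild a full, writable experiment
            return DummyExperiment(self.version)

    experiment = DummyExperiment()
    use_ddp, proc_rank = True, 0

    if use_ddp and experiment is not None:
        experiment = experiment.get_meta_copy()      # before spawn (fit)

    if experiment is not None:                       # inside worker (ddp_train)
        experiment.debug = proc_rank > 0             # write mode only on proc 0
        experiment = experiment.get_non_ddp_exp()
        experiment.rank = proc_rank                  # avoid overwriting logs
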
@@ -673,10 +680,12 @@ def __run_pretrain_routine(self, model):
 
         # give model convenience properties
         ref_model.trainer = self
-        ref_model.experiment = self.experiment
+
+        if self.experiment is not None:
+            ref_model.experiment = self.experiment
 
         # save exp to get started
-        if self.proc_rank == 0:
+        if self.proc_rank == 0 and self.experiment is not None:
             self.experiment.save()
 
         # track model now.
@@ -756,7 +765,7 @@ def __train(self):
 
             # when batch should be saved
             if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
-                if self.proc_rank == 0:
+                if self.proc_rank == 0 and self.experiment is not None:
                     self.experiment.save()
 
             # when metrics should be logged
@@ -784,7 +793,7 @@ def __train(self):
                 # log metrics
                 scalar_metrics = self.__metrics_to_scalars(
                     metrics, blacklist=self.__log_vals_blacklist())
-                if self.proc_rank == 0:
+                if self.proc_rank == 0 and self.experiment is not None:
                     self.experiment.log(scalar_metrics, global_step=self.global_step)
                     self.experiment.save()
 
@@ -813,7 +822,7 @@ def __train(self):
             if stop:
                 return
 
-    def __metrics_to_scalars(self, metrics, blacklist=[]):
+    def __metrics_to_scalars(self, metrics, blacklist=set()):
         new_metrics = {}
         for k, v in metrics.items():
             if type(v) is torch.Tensor:
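
The final hunk swaps the default `blacklist` from a list to a set, presumably for O(1) membership tests when filtering keys. Note that `set()` is still a shared, mutable default evaluated once at `def` time; that is only safe as long as the method never mutates the argument. A standalone sketch of the filtering idea; the function name and the `frozenset` default (an immutable alternative that sidesteps the concern entirely) are illustrative choices, not code from this diff:

    import torch

    def metrics_to_scalars(metrics, blacklist=frozenset()):
        """Convert tensor values to Python scalars, skipping blacklisted keys."""
        new_metrics = {}
        for k, v in metrics.items():
            if k in blacklist:            # O(1) membership test on a set
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()              # 0-dim tensor -> Python number
            new_metrics[k] = v
        return new_metrics

    print(metrics_to_scalars({'loss': torch.tensor(0.25), 'step': 3},
                             blacklist={'step'}))
    # {'loss': 0.25}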