
Commit 3d23a56

make experiment param in trainer optional (#77)
* removed forced exp
* modified test to also run without exp
1 parent aa7245d commit 3d23a56

File tree

2 files changed: +22 -19 lines changed


pytorch_lightning/models/trainer.py

Lines changed: 22 additions & 13 deletions
```diff
@@ -50,7 +50,7 @@ def reduce_distributed_output(output, nb_gpus):
 class Trainer(TrainerIO):
 
     def __init__(self,
-                 experiment,
+                 experiment=None,
                  early_stop_callback=None,
                  checkpoint_callback=None,
                  gradient_clip=0,
```
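
With `experiment` defaulting to `None`, both call styles are now valid. A minimal sketch, assuming the test-tube `Experiment` class this codebase logs to; the `save_dir` and `max_nb_epochs` values are illustrative, not from this commit:

```python
from test_tube import Experiment

from pytorch_lightning.models.trainer import Trainer

# with an experiment: exp_save_path is derived from the exp as before
exp = Experiment(save_dir='/tmp/pl_demo')
trainer = Trainer(experiment=exp, max_nb_epochs=1)

# without an experiment: valid after this commit, experiment defaults to None
trainer = Trainer(max_nb_epochs=1)
```
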
```diff
@@ -122,7 +122,9 @@ def __init__(self,
         self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.progress_bar = progress_bar
         self.experiment = experiment
-        self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
+        self.exp_save_path = None
+        if self.experiment is not None:
+            self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
         self.cluster = cluster
         self.process_position = process_position
         self.current_gpu_name = current_gpu_name
```

```diff
@@ -312,13 +314,15 @@ def __is_function_implemented(self, f_name):
 
     @property
     def __tng_tqdm_dic(self):
-        # ForkedPdb().set_trace()
         tqdm_dic = {
             'tng_loss': '{0:.3f}'.format(self.avg_loss),
-            'v_nb': '{}'.format(self.experiment.version),
             'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
+
+        if self.experiment is not None:
+            tqdm_dic['v_nb'] = self.experiment.version
+
         tqdm_dic.update(self.tqdm_metrics)
 
         if self.on_gpu:
```
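
The guarded dict construction above reads as a standalone pattern: without an experiment, the `'v_nb'` key is simply absent from the progress bar. A simplified sketch (the function name and arguments are illustrative, not from the diff):

```python
def tng_tqdm_dic(avg_loss, current_epoch, batch_nb, experiment=None):
    # mirrors the guarded construction in __tng_tqdm_dic, stripped of Trainer state
    tqdm_dic = {
        'tng_loss': '{0:.3f}'.format(avg_loss),
        'epoch': '{}'.format(current_epoch),
        'batch_nb': '{}'.format(batch_nb),
    }
    if experiment is not None:
        tqdm_dic['v_nb'] = experiment.version
    return tqdm_dic

print(tng_tqdm_dic(0.1234, 1, 40))   # no experiment -> no 'v_nb' key
```
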
```diff
@@ -462,7 +466,8 @@ def fit(self, model):
         if self.use_ddp:
             # must copy only the meta of the exp so it survives pickle/unpickle
             # when going to new process
-            self.experiment = self.experiment.get_meta_copy()
+            if self.experiment is not None:
+                self.experiment = self.experiment.get_meta_copy()
 
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
```
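
The meta copy matters because DDP spawns new processes, and everything handed to them must survive pickle/unpickle; a live experiment may hold open file handles that cannot be pickled. A toy illustration of the constraint (`ToyExp` is a stand-in, not test-tube's class):

```python
import pickle

class ToyExp:
    """Toy stand-in for an experiment; not test-tube's Experiment."""
    def __init__(self):
        self.log_file = open('/tmp/toy_exp.log', 'w')   # live handle: unpicklable

    def get_meta_copy(self):
        meta = ToyExp.__new__(ToyExp)   # copy metadata without the open handle
        meta.log_file = None
        return meta

exp = ToyExp()
# pickle.dumps(exp)                # TypeError: cannot pickle file objects
pickle.dumps(exp.get_meta_copy())  # fine: only picklable state crosses processes
```
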
```diff
@@ -564,8 +569,9 @@ def ddp_train(self, gpu_nb, model):
 
         # recover original exp before went into process
         # init in write mode only on proc 0
-        self.experiment.debug = self.proc_rank > 0
-        self.experiment = self.experiment.get_non_ddp_exp()
+        if self.experiment is not None:
+            self.experiment.debug = self.proc_rank > 0
+            self.experiment = self.experiment.get_non_ddp_exp()
 
         # show progbar only on prog_rank 0
         self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0
```

```diff
@@ -575,7 +581,8 @@ def ddp_train(self, gpu_nb, model):
         self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)
 
         # let the exp know the rank to avoid overwriting logs
-        self.experiment.rank = self.proc_rank
+        if self.experiment is not None:
+            self.experiment.rank = self.proc_rank
 
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
```

```diff
@@ -673,10 +680,12 @@ def __run_pretrain_routine(self, model):
 
         # give model convenience properties
         ref_model.trainer = self
-        ref_model.experiment = self.experiment
+
+        if self.experiment is not None:
+            ref_model.experiment = self.experiment
 
         # save exp to get started
-        if self.proc_rank == 0:
+        if self.proc_rank == 0 and self.experiment is not None:
             self.experiment.save()
 
         # track model now.
```

```diff
@@ -756,7 +765,7 @@ def __train(self):
 
             # when batch should be saved
             if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
-                if self.proc_rank == 0:
+                if self.proc_rank == 0 and self.experiment is not None:
                     self.experiment.save()
 
             # when metrics should be logged
```

```diff
@@ -784,7 +793,7 @@ def __train(self):
                 # log metrics
                 scalar_metrics = self.__metrics_to_scalars(
                     metrics, blacklist=self.__log_vals_blacklist())
-                if self.proc_rank == 0:
+                if self.proc_rank == 0 and self.experiment is not None:
                     self.experiment.log(scalar_metrics, global_step=self.global_step)
                     self.experiment.save()
 
```
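The rank-0 checks now also require an experiment, so logging is skipped both on non-zero ranks and when no experiment was given. Since the same compound guard appears at several call sites (pretrain routine, batch save, metric logging), one possible follow-up is a small helper; purely a sketch, not part of this commit:

```python
def exp_call(trainer, method_name, *args, **kwargs):
    # run experiment.save()/log() only on rank 0 and only if an exp exists
    if trainer.proc_rank == 0 and trainer.experiment is not None:
        getattr(trainer.experiment, method_name)(*args, **kwargs)

# e.g. exp_call(self, 'log', scalar_metrics, global_step=self.global_step)
```
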
```diff
@@ -813,7 +822,7 @@ def __train(self):
             if stop:
                 return
 
-    def __metrics_to_scalars(self, metrics, blacklist=[]):
+    def __metrics_to_scalars(self, metrics, blacklist=set()):
         new_metrics = {}
         for k, v in metrics.items():
             if type(v) is torch.Tensor:
```
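
The `blacklist` default also changes from `[]` to `set()`, which gives O(1) membership tests; note that any mutable default is still shared across calls, so it is only safe while the function never mutates it. A quick generic demonstration of that pitfall, not from this file:

```python
def tracked_bad(metric, seen=[]):
    seen.append(metric)        # mutates the one shared default list
    return seen

print(tracked_bad('loss'))     # ['loss']
print(tracked_bad('acc'))      # ['loss', 'acc']  <- state leaks across calls

def tracked_good(metric, seen=None):
    seen = [] if seen is None else seen
    seen.append(metric)
    return seen

print(tracked_good('loss'))    # ['loss']
print(tracked_good('acc'))     # ['acc']
```
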

tests/test_models.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -37,16 +37,10 @@ def test_simple_cpu():
     save_dir = init_save_dir()
 
     # exp file to get meta
-    test_exp_version = 10
-    exp = get_exp(False, version=test_exp_version)
-    exp.argparse(hparams)
-    exp.save()
-
     trainer_options = dict(
         max_nb_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.1,
-        experiment=exp,
     )
 
     # fit model
```
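
With the experiment setup removed, the CPU test now exercises the new default path. A hedged sketch of the slimmed-down test; the test name here is hypothetical, `tempfile.mkdtemp()` stands in for the module's `init_save_dir()` helper, and the model/fit steps elided by the diff are only indicated in comments:

```python
import tempfile

from pytorch_lightning.models.trainer import Trainer

def test_simple_cpu_without_exp():
    save_dir = tempfile.mkdtemp()   # stand-in for init_save_dir() (assumption)

    trainer_options = dict(
        max_nb_epochs=1,
        val_percent_check=0.1,
        train_percent_check=0.1,
        # no experiment key anymore: Trainer falls back to experiment=None
    )

    trainer = Trainer(**trainer_options)
    # trainer.fit(model) and the assertions follow in the real test,
    # using the model class defined by the surrounding test module
```
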
