1414import logging
1515import pathlib
1616import h5py
17+ import random
1718
1819config = """
1920latent_size: 128
2223batch_size: 128
2324test_batch_size: 512
2425max_iterations: 100000
25- log_interval: 5000
26+ log_interval: 10000
2627n_samples: 128
2728use_gpu: true
2829train_dir: $TMPDIR
30+ seed: 582838
2931"""
3032
3133
@@ -116,7 +118,7 @@ def load_binary_mnist(cfg, **kwcfg):
116118 x_val = f ['valid' ][::]
117119 x_test = f ['test' ][::]
118120 train = torch .utils .data .TensorDataset (torch .from_numpy (x_train ))
119- train_loader = torch .utils .data .DataLoader (train , batch_size = cfg .batch_size , shuffle = True )
121+ train_loader = torch .utils .data .DataLoader (train , batch_size = cfg .batch_size , shuffle = True , ** kwcfg )
120122 validation = torch .utils .data .TensorDataset (torch .from_numpy (x_val ))
121123 val_loader = torch .utils .data .DataLoader (validation , batch_size = cfg .test_batch_size , shuffle = False )
122124 test = torch .utils .data .TensorDataset (torch .from_numpy (x_test ))
@@ -148,6 +150,9 @@ def evaluate(n_samples, model, variational, eval_data):
148150 dictionary = yaml .load (config )
149151 cfg = nomen .Config (dictionary )
150152 device = torch .device ("cuda:0" if cfg .use_gpu else "cpu" )
153+ torch .manual_seed (cfg .seed )
154+ np .random .seed (cfg .seed )
155+ random .seed (cfg .seed )
151156
152157 model = Model (latent_size = cfg .latent_size ,
153158 data_size = cfg .data_size ,
@@ -163,7 +168,7 @@ def evaluate(n_samples, model, variational, eval_data):
163168 lr = cfg .learning_rate ,
164169 centered = True )
165170
166- kwargs = {'num_workers' : 0 , 'pin_memory' : False } if cfg .use_gpu else {}
171+ kwargs = {'num_workers' : 4 , 'pin_memory' : True } if cfg .use_gpu else {}
167172 train_data , valid_data , test_data = load_binary_mnist (cfg , ** kwargs )
168173
169174 best_valid_elbo = - np .inf
@@ -188,6 +193,7 @@ def evaluate(n_samples, model, variational, eval_data):
188193 valid_elbo , valid_log_p_x = evaluate (cfg .n_samples , model , variational , valid_data )
189194 print (f'step:\t { step } \t \t valid elbo: { valid_elbo :.2f} \t valid log p(x): { valid_log_p_x :.2f} ' )
190195 if valid_elbo > best_valid_elbo :
196+ num_no_improvement = 0
191197 best_valid_elbo = valid_elbo
192198 states = {'model' : model .state_dict (),
193199 'variational' : variational .state_dict ()}
0 commit comments