Skip to content

Commit 68c8535

Browse files
author
Jaan Altosaar
committed
pin memory; fix seed
1 parent aeee4f9 commit 68c8535

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

train_variational_autoencoder_pytorch.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import logging
1515
import pathlib
1616
import h5py
17+
import random
1718

1819
config = """
1920
latent_size: 128
@@ -22,10 +23,11 @@
2223
batch_size: 128
2324
test_batch_size: 512
2425
max_iterations: 100000
25-
log_interval: 5000
26+
log_interval: 10000
2627
n_samples: 128
2728
use_gpu: true
2829
train_dir: $TMPDIR
30+
seed: 582838
2931
"""
3032

3133

@@ -116,7 +118,7 @@ def load_binary_mnist(cfg, **kwcfg):
116118
x_val = f['valid'][::]
117119
x_test = f['test'][::]
118120
train = torch.utils.data.TensorDataset(torch.from_numpy(x_train))
119-
train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True)
121+
train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True, **kwcfg)
120122
validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val))
121123
val_loader = torch.utils.data.DataLoader(validation, batch_size=cfg.test_batch_size, shuffle=False)
122124
test = torch.utils.data.TensorDataset(torch.from_numpy(x_test))
@@ -148,6 +150,9 @@ def evaluate(n_samples, model, variational, eval_data):
148150
dictionary = yaml.load(config)
149151
cfg = nomen.Config(dictionary)
150152
device = torch.device("cuda:0" if cfg.use_gpu else "cpu")
153+
torch.manual_seed(cfg.seed)
154+
np.random.seed(cfg.seed)
155+
random.seed(cfg.seed)
151156

152157
model = Model(latent_size=cfg.latent_size,
153158
data_size=cfg.data_size,
@@ -163,7 +168,7 @@ def evaluate(n_samples, model, variational, eval_data):
163168
lr=cfg.learning_rate,
164169
centered=True)
165170

166-
kwargs = {'num_workers': 0, 'pin_memory': False} if cfg.use_gpu else {}
171+
kwargs = {'num_workers': 4, 'pin_memory': True} if cfg.use_gpu else {}
167172
train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs)
168173

169174
best_valid_elbo = -np.inf
@@ -188,6 +193,7 @@ def evaluate(n_samples, model, variational, eval_data):
188193
valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data)
189194
print(f'step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}')
190195
if valid_elbo > best_valid_elbo:
196+
num_no_improvement = 0
191197
best_valid_elbo = valid_elbo
192198
states = {'model': model.state_dict(),
193199
'variational': variational.state_dict()}

0 commit comments

Comments
 (0)