Commit aeee4f9
Author: Jaan Altosaar
fix test marginal likelihood
1 parent 7ec7132
1 file changed: 114 additions, 59 deletions
@@ -1,17 +1,19 @@
-"""Fit a VAE to MNIST.
+"""Fit a variational autoencoder to MNIST.
 
-Conventions:
+Notes:
+- run https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py to download binary MNIST file
 - batch size is the innermost dimension, then the sample dimension, then latent dimension
 """
 import torch
 import torch.utils
+import torch.utils.data
 from torch import nn
 import nomen
 import yaml
 import numpy as np
 import logging
-
-import data
+import pathlib
+import h5py
 
 config = """
 latent_size: 128
@@ -20,96 +22,118 @@
 batch_size: 128
 test_batch_size: 512
 max_iterations: 100000
-log_interval: 1000
-n_samples: 77
+log_interval: 5000
+n_samples: 128
+use_gpu: true
+train_dir: $TMPDIR
 """
 
-class NeuralNetwork(nn.Module):
-    def __init__(self, input_size, output_size, hidden_size):
-        super().__init__()
-        modules = [nn.Linear(input_size, hidden_size),
-                   nn.ReLU(),
-                   nn.Linear(hidden_size, hidden_size),
-                   nn.ReLU(),
-                   nn.Linear(hidden_size, output_size)]
-        self.net = nn.Sequential(*modules)
-
-    def forward(self, input):
-        return self.net(input)
-
-
 
 class Model(nn.Module):
     """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST."""
-    def __init__(self, latent_size, data_size, batch_size):
+    def __init__(self, latent_size, data_size, batch_size, device):
         super().__init__()
-        # prior on latents is standard normal
-        self.p_z = torch.distributions.Normal(torch.zeros(latent_size), torch.ones(latent_size))
-        # likelihood is bernoulli, equivalent to negative binary cross entropy
+        self.p_z = torch.distributions.Normal(
+            torch.zeros(latent_size, device=device),
+            torch.ones(latent_size, device=device))
         self.log_p_x = BernoulliLogProb()
-        # generative network is a MLP
-        self.generative_network = NeuralNetwork(input_size=latent_size, output_size=data_size, hidden_size=latent_size * 2)
-
+        self.generative_network = NeuralNetwork(input_size=latent_size,
+                                                output_size=data_size,
+                                                hidden_size=latent_size * 2)
 
     def forward(self, z, x):
         """Return log probability of model."""
         log_p_z = self.p_z.log_prob(z).sum(-1)
         logits = self.generative_network(z)
+        # unsqueeze sample dimension
+        logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1))
         log_p_x = self.log_p_x(logits, x).sum(-1)
         return log_p_z + log_p_x
 
-
-class NormalLogProb(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, loc, scale, z):
-        var = torch.pow(scale, 2)
-        return -0.5 * torch.log(2 * np.pi * var) + torch.pow(z - loc, 2) / (2 * var)
-
-class BernoulliLogProb(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')
-
-    def forward(self, logits, target):
-        logits, target = torch.broadcast_tensors(logits, target.unsqueeze(1))
-        return -self.bce_with_logits(logits, target)
 
 class Variational(nn.Module):
     """Approximate posterior parameterized by an inference network."""
     def __init__(self, latent_size, data_size):
         super().__init__()
-        self.inference_network = NeuralNetwork(input_size=data_size, output_size=latent_size * 2, hidden_size=latent_size*2)
+        self.inference_network = NeuralNetwork(input_size=data_size,
+                                               output_size=latent_size * 2,
+                                               hidden_size=latent_size*2)
         self.log_q_z = NormalLogProb()
         self.softplus = nn.Softplus()
 
     def forward(self, x, n_samples=1):
         """Return sample of latent variable and log prob."""
         loc, scale_arg = torch.chunk(self.inference_network(x).unsqueeze(1), chunks=2, dim=-1)
         scale = self.softplus(scale_arg)
-        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]))
+        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
         z = loc + scale * eps  # reparameterization
         log_q_z = self.log_q_z(loc, scale, z).sum(-1)
         return z, log_q_z
 
 
+class NeuralNetwork(nn.Module):
+    def __init__(self, input_size, output_size, hidden_size):
+        super().__init__()
+        modules = [nn.Linear(input_size, hidden_size),
+                   nn.ReLU(),
+                   nn.Linear(hidden_size, hidden_size),
+                   nn.ReLU(),
+                   nn.Linear(hidden_size, output_size)]
+        self.net = nn.Sequential(*modules)
+
+    def forward(self, input):
+        return self.net(input)
+
+
+class NormalLogProb(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, loc, scale, z):
+        var = torch.pow(scale, 2)
+        return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var)
+
+
+class BernoulliLogProb(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')
+
+    def forward(self, logits, target):
+        # bernoulli log prob is equivalent to negative binary cross entropy
+        return -self.bce_with_logits(logits, target)
+
+
 def cycle(iterable):
     while True:
         for x in iterable:
             yield x
 
 
+def load_binary_mnist(cfg, **kwcfg):
+    f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r')
+    x_train = f['train'][::]
+    x_val = f['valid'][::]
+    x_test = f['test'][::]
+    train = torch.utils.data.TensorDataset(torch.from_numpy(x_train))
+    train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True)
+    validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val))
+    val_loader = torch.utils.data.DataLoader(validation, batch_size=cfg.test_batch_size, shuffle=False)
+    test = torch.utils.data.TensorDataset(torch.from_numpy(x_test))
+    test_loader = torch.utils.data.DataLoader(test, batch_size=cfg.test_batch_size, shuffle=False)
+    return train_loader, val_loader, test_loader
+
+
 def evaluate(n_samples, model, variational, eval_data):
     model.eval()
     total_log_p_x = 0.0
     total_elbo = 0.0
     for batch in eval_data:
-        x = batch[0]
+        x = batch[0].to(next(model.parameters()).device)
         z, log_q_z = variational(x, n_samples)
         log_p_x_and_z = model(z, x)
-        # importance sampling of approximate marginal likelihood
-        # using logsumexp in the sample dimension
+        # importance sampling of approximate marginal likelihood with q(z)
+        # as the proposal, and logsumexp in the sample dimension
        elbo = log_p_x_and_z - log_q_z
         log_p_x = torch.logsumexp(elbo, dim=1) - np.log(n_samples)
         # average over sample dimension, sum over minibatch
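Aside on the estimator used in evaluate: with K samples z_k drawn from the approximate posterior q(z | x), the marginal likelihood is estimated by importance sampling, log p(x) ≈ logsumexp_k[log p(x, z_k) − log q(z_k | x)] − log K. A minimal sketch of this computation on stand-in tensors (the shapes and random values below are illustrative, not taken from the repository):

import numpy as np
import torch

# stand-in log joint and log proposal values with shape (batch, n_samples);
# in the script these come from Model.forward and Variational.forward
batch_size, n_samples = 4, 16
log_p_x_and_z = torch.randn(batch_size, n_samples)
log_q_z = torch.randn(batch_size, n_samples)

# log importance weights: log w_k = log p(x, z_k) - log q(z_k | x)
log_w = log_p_x_and_z - log_q_z

# importance-sampled estimate of log p(x): logsumexp over the sample dimension
log_p_x = torch.logsumexp(log_w, dim=1) - np.log(n_samples)

# averaging the log weights instead gives the ELBO, a lower bound on log p(x)
elbo = log_w.mean(dim=1)
print(log_p_x.shape, elbo.shape)  # both torch.Size([4])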
@@ -123,28 +147,59 @@ def evaluate(n_samples, model, variational, eval_data):
 if __name__ == '__main__':
     dictionary = yaml.load(config)
     cfg = nomen.Config(dictionary)
-
-    model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size, batch_size=cfg.batch_size)
-    variational = Variational(latent_size=cfg.latent_size, data_size=cfg.data_size)
+    device = torch.device("cuda:0" if cfg.use_gpu else "cpu")
+
+    model = Model(latent_size=cfg.latent_size,
+                  data_size=cfg.data_size,
+                  batch_size=cfg.batch_size,
+                  device=device)
+    variational = Variational(latent_size=cfg.latent_size,
+                              data_size=cfg.data_size)
+    model.to(device)
+    variational.to(device)
+
+    optimizer = torch.optim.RMSprop(list(model.parameters()) +
+                                    list(variational.parameters()),
+                                    lr=cfg.learning_rate,
+                                    centered=True)
 
-    optimizer = torch.optim.RMSprop(list(model.parameters()) + list(variational.parameters()),
-                                    lr=cfg.learning_rate)
+    kwargs = {'num_workers': 0, 'pin_memory': False} if cfg.use_gpu else {}
+    train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs)
 
-    train_data, valid_data, test_data = data.load_binary_mnist(cfg)
+    best_valid_elbo = -np.inf
+    num_no_improvement = 0
 
     for step, batch in enumerate(cycle(train_data)):
-        x = batch[0]
+        x = batch[0].to(device)
         model.zero_grad()
         variational.zero_grad()
         z, log_q_z = variational(x)
         log_p_x_and_z = model(z, x)
+        # average over sample dimension
         elbo = (log_p_x_and_z - log_q_z).mean(1)
-        loss = -elbo.mean(0)
+        # sum over batch dimension
+        loss = -elbo.sum(0)
         loss.backward()
         optimizer.step()
 
         if step % cfg.log_interval == 0:
-            print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy()[0]:.2f}')
+            print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy().mean():.2f}')
             with torch.no_grad():
                 valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data)
-            print(f'step:\t{step}\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}')
+            print(f'step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}')
+            if valid_elbo > best_valid_elbo:
+                best_valid_elbo = valid_elbo
+                states = {'model': model.state_dict(),
+                          'variational': variational.state_dict()}
+                torch.save(states, cfg.train_dir / 'best_state_dict')
+            else:
+                num_no_improvement += 1
+
+            if num_no_improvement > 5:
+                checkpoint = torch.load(cfg.train_dir / 'best_state_dict')
+                model.load_state_dict(checkpoint['model'])
+                variational.load_state_dict(checkpoint['variational'])
+                with torch.no_grad():
+                    test_elbo, test_log_p_x = evaluate(cfg.n_samples, model, variational, test_data)
+                print(f'step:\t{step}\t\ttest elbo: {test_elbo:.2f}\ttest log p(x): {test_log_p_x:.2f}')
+                break
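The other substantive change above flips a sign in NormalLogProb: the squared-error term of the Gaussian log density must be subtracted, not added. As a sanity check outside the commit, the corrected formula can be compared against torch.distributions.Normal.log_prob; a minimal sketch (helper name and test values are illustrative):

import numpy as np
import torch

def normal_log_prob(loc, scale, z):
    # corrected diagonal Gaussian log density, matching the fixed NormalLogProb.forward
    var = torch.pow(scale, 2)
    return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var)

loc = torch.zeros(5)
scale = 0.7 * torch.ones(5)
z = torch.randn(5)

manual = normal_log_prob(loc, scale, z)
reference = torch.distributions.Normal(loc, scale).log_prob(z)
print(torch.allclose(manual, reference, atol=1e-6))  # expected: True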
