-"""Fit a VAE to MNIST.
+"""Fit a variational autoencoder to MNIST.
 
-Conventions:
+Notes:
+- run https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py to download binary MNIST file
 - batch size is the innermost dimension, then the sample dimension, then latent dimension
 """
 import torch
 import torch.utils
+import torch.utils.data
 from torch import nn
 import nomen
 import yaml
 import numpy as np
 import logging
-
-import data
+import pathlib
+import h5py
 
 config = """
 latent_size: 128
 batch_size: 128
 test_batch_size: 512
 max_iterations: 100000
-log_interval: 1000
-n_samples: 77
+log_interval: 5000
+n_samples: 128
+use_gpu: true
+train_dir: $TMPDIR
 """
 
-
-class NeuralNetwork(nn.Module):
-    def __init__(self, input_size, output_size, hidden_size):
-        super().__init__()
-        modules = [nn.Linear(input_size, hidden_size),
-                   nn.ReLU(),
-                   nn.Linear(hidden_size, hidden_size),
-                   nn.ReLU(),
-                   nn.Linear(hidden_size, output_size)]
-        self.net = nn.Sequential(*modules)
-
-    def forward(self, input):
-        return self.net(input)
-
-
 
 class Model(nn.Module):
     """Bernoulli model parameterized by a generative network with Gaussian latents for MNIST."""
-    def __init__(self, latent_size, data_size, batch_size):
+    def __init__(self, latent_size, data_size, batch_size, device):
         super().__init__()
-        # prior on latents is standard normal
-        self.p_z = torch.distributions.Normal(torch.zeros(latent_size), torch.ones(latent_size))
-        # likelihood is bernoulli, equivalent to negative binary cross entropy
+        self.p_z = torch.distributions.Normal(
+            torch.zeros(latent_size, device=device),
+            torch.ones(latent_size, device=device))
         self.log_p_x = BernoulliLogProb()
-        # generative network is a MLP
-        self.generative_network = NeuralNetwork(input_size=latent_size, output_size=data_size, hidden_size=latent_size * 2)
-
+        self.generative_network = NeuralNetwork(input_size=latent_size,
+                                                output_size=data_size,
+                                                hidden_size=latent_size * 2)
 
     def forward(self, z, x):
         """Return log probability of model."""
         log_p_z = self.p_z.log_prob(z).sum(-1)
         logits = self.generative_network(z)
+        # unsqueeze sample dimension
+        logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1))
         log_p_x = self.log_p_x(logits, x).sum(-1)
         return log_p_z + log_p_x
 
-
-class NormalLogProb(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, loc, scale, z):
-        var = torch.pow(scale, 2)
-        return -0.5 * torch.log(2 * np.pi * var) + torch.pow(z - loc, 2) / (2 * var)
-
-class BernoulliLogProb(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')
-
-    def forward(self, logits, target):
-        logits, target = torch.broadcast_tensors(logits, target.unsqueeze(1))
-        return -self.bce_with_logits(logits, target)
 
 class Variational(nn.Module):
     """Approximate posterior parameterized by an inference network."""
     def __init__(self, latent_size, data_size):
         super().__init__()
-        self.inference_network = NeuralNetwork(input_size=data_size, output_size=latent_size * 2, hidden_size=latent_size * 2)
+        self.inference_network = NeuralNetwork(input_size=data_size,
+                                               output_size=latent_size * 2,
+                                               hidden_size=latent_size * 2)
         self.log_q_z = NormalLogProb()
         self.softplus = nn.Softplus()
 
     def forward(self, x, n_samples=1):
         """Return sample of latent variable and log prob."""
         loc, scale_arg = torch.chunk(self.inference_network(x).unsqueeze(1), chunks=2, dim=-1)
         scale = self.softplus(scale_arg)
-        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]))
+        eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device)
         z = loc + scale * eps  # reparameterization
         log_q_z = self.log_q_z(loc, scale, z).sum(-1)
         return z, log_q_z
 
 
+class NeuralNetwork(nn.Module):
+    def __init__(self, input_size, output_size, hidden_size):
+        super().__init__()
+        modules = [nn.Linear(input_size, hidden_size),
+                   nn.ReLU(),
+                   nn.Linear(hidden_size, hidden_size),
+                   nn.ReLU(),
+                   nn.Linear(hidden_size, output_size)]
+        self.net = nn.Sequential(*modules)
+
+    def forward(self, input):
+        return self.net(input)
+
+
+class NormalLogProb(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, loc, scale, z):
+        var = torch.pow(scale, 2)
+        return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var)
+
+
+class BernoulliLogProb(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.bce_with_logits = nn.BCEWithLogitsLoss(reduction='none')
+
+    def forward(self, logits, target):
+        # bernoulli log prob is equivalent to negative binary cross entropy
+        return -self.bce_with_logits(logits, target)
+
+
 def cycle(iterable):
     while True:
         for x in iterable:
             yield x
 
 
+def load_binary_mnist(cfg, **kwcfg):
+    f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r')
+    x_train = f['train'][::]
+    x_val = f['valid'][::]
+    x_test = f['test'][::]
+    train = torch.utils.data.TensorDataset(torch.from_numpy(x_train))
+    train_loader = torch.utils.data.DataLoader(train, batch_size=cfg.batch_size, shuffle=True)
+    validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val))
+    val_loader = torch.utils.data.DataLoader(validation, batch_size=cfg.test_batch_size, shuffle=False)
+    test = torch.utils.data.TensorDataset(torch.from_numpy(x_test))
+    test_loader = torch.utils.data.DataLoader(test, batch_size=cfg.test_batch_size, shuffle=False)
+    return train_loader, val_loader, test_loader
+
+
 def evaluate(n_samples, model, variational, eval_data):
     model.eval()
     total_log_p_x = 0.0
     total_elbo = 0.0
     for batch in eval_data:
-        x = batch[0]
+        x = batch[0].to(next(model.parameters()).device)
         z, log_q_z = variational(x, n_samples)
         log_p_x_and_z = model(z, x)
-        # importance sampling of approximate marginal likelihood
-        # using logsumexp in the sample dimension
+        # importance sampling of approximate marginal likelihood with q(z)
+        # as the proposal, and logsumexp in the sample dimension
         elbo = log_p_x_and_z - log_q_z
         log_p_x = torch.logsumexp(elbo, dim=1) - np.log(n_samples)
         # average over sample dimension, sum over minibatch
@@ -123,28 +147,59 @@ def evaluate(n_samples, model, variational, eval_data):
 if __name__ == '__main__':
     dictionary = yaml.load(config)
     cfg = nomen.Config(dictionary)
-
-    model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size, batch_size=cfg.batch_size)
-    variational = Variational(latent_size=cfg.latent_size, data_size=cfg.data_size)
+    device = torch.device("cuda:0" if cfg.use_gpu else "cpu")
+
+    model = Model(latent_size=cfg.latent_size,
+                  data_size=cfg.data_size,
+                  batch_size=cfg.batch_size,
+                  device=device)
+    variational = Variational(latent_size=cfg.latent_size,
+                              data_size=cfg.data_size)
+    model.to(device)
+    variational.to(device)
+
+    optimizer = torch.optim.RMSprop(list(model.parameters()) +
+                                    list(variational.parameters()),
+                                    lr=cfg.learning_rate,
+                                    centered=True)
 
-    optimizer = torch.optim.RMSprop(list(model.parameters()) + list(variational.parameters()),
-                                    lr=cfg.learning_rate)
+    kwargs = {'num_workers': 0, 'pin_memory': False} if cfg.use_gpu else {}
+    train_data, valid_data, test_data = load_binary_mnist(cfg, **kwargs)
 
-    train_data, valid_data, test_data = data.load_binary_mnist(cfg)
+    best_valid_elbo = -np.inf
+    num_no_improvement = 0
 
     for step, batch in enumerate(cycle(train_data)):
-        x = batch[0]
+        x = batch[0].to(device)
         model.zero_grad()
         variational.zero_grad()
         z, log_q_z = variational(x)
         log_p_x_and_z = model(z, x)
+        # average over sample dimension
        elbo = (log_p_x_and_z - log_q_z).mean(1)
-        loss = -elbo.mean(0)
+        # sum over batch dimension
+        loss = -elbo.sum(0)
         loss.backward()
         optimizer.step()
 
         if step % cfg.log_interval == 0:
-            print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy()[0]:.2f}')
+            print(f'step:\t{step}\ttrain elbo: {elbo.detach().cpu().numpy().mean():.2f}')
             with torch.no_grad():
                 valid_elbo, valid_log_p_x = evaluate(cfg.n_samples, model, variational, valid_data)
-            print(f'step:\t{step}\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}')
+            print(f'step:\t{step}\t\tvalid elbo: {valid_elbo:.2f}\tvalid log p(x): {valid_log_p_x:.2f}')
+            if valid_elbo > best_valid_elbo:
+                best_valid_elbo = valid_elbo
+                states = {'model': model.state_dict(),
+                          'variational': variational.state_dict()}
+                torch.save(states, cfg.train_dir / 'best_state_dict')
+            else:
+                num_no_improvement += 1
+
+            if num_no_improvement > 5:
+                checkpoint = torch.load(cfg.train_dir / 'best_state_dict')
+                model.load_state_dict(checkpoint['model'])
+                variational.load_state_dict(checkpoint['variational'])
+                with torch.no_grad():
+                    test_elbo, test_log_p_x = evaluate(cfg.n_samples, model, variational, test_data)
+                print(f'step:\t{step}\t\ttest elbo: {test_elbo:.2f}\ttest log p(x): {test_log_p_x:.2f}')
+                break
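
Note on the evaluation logic in this diff: evaluate() estimates the marginal likelihood by importance sampling with the variational posterior q(z | x) as the proposal, i.e. log p(x) is approximated by logsumexp over K samples of log p(x, z_k) - log q(z_k | x), minus log K. A minimal, self-contained sketch of just that estimator, using toy tensors in place of the script's model outputs (not part of the commit):

import numpy as np
import torch

# toy stand-ins: 4 data points, K = 3 importance samples each, shape (batch, sample)
log_p_x_and_z = torch.randn(4, 3)  # stands in for log p(x, z_k) under the model
log_q_z = torch.randn(4, 3)        # stands in for log q(z_k | x) under the variational posterior

log_w = log_p_x_and_z - log_q_z    # log importance weights (per-sample ELBO terms)
elbo = log_w.mean(dim=1)           # ELBO, averaged over the sample dimension
log_p_x = torch.logsumexp(log_w, dim=1) - np.log(log_w.shape[1])  # importance-sampled log p(x) estimate

# by Jensen's inequality the ELBO never exceeds the logsumexp-based estimate
assert torch.all(elbo <= log_p_x + 1e-6)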