
UG-VAE

This repository contains the official PyTorch implementation of the Unsupervised Global VAE (UG-VAE) model proposed in the paper Unsupervised Learning of Global Factors in Deep Generative Models.

If you use this code, please cite the preprint:

@article{peis2020unsupervised,
  title={Unsupervised Learning of Global Factors in Deep Generative Models},
  author={Peis, Ignacio and Olmos, Pablo M and Art{\'e}s-Rodr{\'\i}guez, Antonio},
  journal={arXiv preprint arXiv:2012.08234},
  year={2020}
}

Dependencies

torch 1.7.0
torchvision 0.8.1
matplotlib 3.3.3
numpy 1.19.4
pandas 1.1.4
scikit-learn 0.23.2

Usage

UG-VAE is implemented as a class that inherits from the PyTorch nn.Module in models.py. You can train UG-VAE using the train.py script. A few examples are included below:

# example for training on CelebA
python3 train.py --dataset celeba --arch beta_vae --epochs 10 --model_name celeba

# example for training on MNIST
python3 train.py --dataset mnist --arch k_vae --dim_z 10 --dim_beta 20 --K 10 --model_name mnist

# example for training on mixed CelebA+FACES
python3 train.py --dataset celeba_faces --arch beta_vae --dim_z 40 --dim_beta 40 --K 40 --model_name celeba_faces

The script creates a log directory in results/[model_name]. Its checkpoints/ subdirectory stores a model checkpoint at each log interval, and its figs/ subdirectory stores reconstructions and samples at the end of each log interval, along with a plot of the losses during training. Notable arguments:

  • --dataset: choose among celeba, mnist, celeba_faces, cars_chairs and others included in datasets.py. For CelebA, 3D FACES, Cars, and Chairs you must download the images from the respective links and place them in data/[dataset]/img/.
  • --arch: the architecture for the encoder and decoder networks: use beta_vae for convolutional networks that work with 64x64 images (e.g. CelebA), or k_vae for fully connected networks that work with 28x28 images (e.g. MNIST).
  • --model_name: the name of the log directory created in the results/ folder.
  • --no_cuda: disable GPU training.
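To give intuition for what --dim_z, --dim_beta and --K control, below is a minimal, hypothetical sketch (not the repository's actual implementation) of the core idea described in the paper: each image in a batch gets its own local latent code z of size dim_z, while the whole batch shares a single global latent code beta of size dim_beta, obtained by aggregating over the batch. All class and variable names here are illustrative.

```python
import torch
import torch.nn as nn

class ToyGlobalLocalEncoder(nn.Module):
    """Illustrative only: per-sample local codes plus one batch-level global code."""

    def __init__(self, dim_x=784, dim_z=10, dim_beta=20):
        super().__init__()
        self.local_head = nn.Linear(dim_x, 2 * dim_z)      # per-sample mu, logvar
        self.global_head = nn.Linear(dim_x, 2 * dim_beta)  # aggregated over the batch

    def forward(self, x):
        # Local posterior parameters: one (mu, logvar) pair per image
        mu_z, logvar_z = self.local_head(x).chunk(2, dim=-1)
        # Global posterior parameters: average over the batch dimension,
        # so beta is a single latent shared by every image in the batch
        h = self.global_head(x).mean(dim=0)
        mu_beta, logvar_beta = h.chunk(2, dim=-1)
        return mu_z, logvar_z, mu_beta, logvar_beta

x = torch.randn(32, 784)  # a batch of 32 flattened 28x28 images
mu_z, _, mu_beta, _ = ToyGlobalLocalEncoder()(x)
print(mu_z.shape, mu_beta.shape)  # local codes: (32, 10); global code: (20,)
```

In the actual model, --K additionally sets the number of components in the mixture prior; see models.py for the real architecture.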

Examples

UG-VAE learns both local and global disentanglement from random batches of data in a fully unsupervised manner, leading to promising performance in domain alignment and in discovering non-trivial underlying structure. Several experiments are included in the experiments/ folder.

Interpolation

In experiments/interpolation.py you have an implementation of Experiment 4.1 of the paper. By running:

python3 interpolation.py --dataset celeba --arch beta_vae --epochs 10 --model_name celeba 

you will store figures like the following in results/[model_name]/figs/interpolation/:
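The core operation behind this experiment is a linear interpolation in latent space: take two latent codes, move between them in evenly spaced steps, and decode each intermediate point. The sketch below illustrates only that idea; the function and variable names (interpolate, z_a, z_b, steps) are hypothetical, not the script's actual API.

```python
import torch

def interpolate(z_a, z_b, steps=8):
    """Return `steps` latent codes evenly spaced on the segment from z_a to z_b."""
    ts = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # shape: (steps, 1)
    # Broadcasting gives one interpolated code per step: (steps, dim_z)
    return (1 - ts) * z_a + ts * z_b

z_a, z_b = torch.zeros(1, 10), torch.ones(1, 10)
path = interpolate(z_a, z_b)
print(path.shape)  # (8, 10): first row equals z_a, last row equals z_b
```

In UG-VAE this can be done both for the local code z (varying per-image factors) and for the global code beta (varying batch-level factors), which is what the resulting figures visualize.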



Domain alignment

In experiments/domain_alignment.py you have an implementation of Experiment 4.2 of the paper. By running:

python3 domain_alignment.py --dataset celeba --arch beta_vae --epochs 10 --model_name celeba

you will store a figure similar to the following in results/[model_name]/figs/domain_alignment/:



Global structure

In experiments/celeba_attributes.py and experiments/mnist_series.py you have implementations of Experiment 4.3 of the paper. By running:

python3 celeba_attributes.py

you will obtain figures like the following:


If you run:

python3 mnist_series.py

you will obtain figures like the following:


Contributors

Ignacio Peis, Pablo M. Olmos and Antonio Artés-Rodríguez.

For further information, contact ipeis@tsc.uc3m.es.