The Official PyTorch Implementation of "Poisson Variational Autoencoder" (NeurIPS 2024 Spotlight Paper)
Welcome to the "Poisson Variational Autoencoder" (P-VAE) codebase! P-VAE is a brain-inspired generative model that unifies major theories in neuroscience with modern machine learning.
When trained on whitened natural image patches, the P-VAE learns sparse, "Gabor-like" features.
This is significant because Gabor-like receptive fields are exactly the kind of selectivity you would observe if you recorded from actual neurons in the primary visual cortex. Remarkably, the P-VAE develops similar selectivity in a purely unsupervised manner, despite never being exposed to data from real neurons.
To learn more, check out:
- Research paper: https://openreview.net/forum?id=ektPEcqGLb
- X summary thread: https://x.com/hadivafaii/status/1794467115510227442
- Talk: https://www.youtube.com/live/Y9hP79tBXHo
Repository structure:

- ./main/: Full architecture and training code for all four VAEs, including the P-VAE, reproducing paper results.
- ./base/distributions.py: Distributions used in the paper, including Poisson with our novel reparameterization algorithm.
- ./analysis/: Data analysis and result generation code.
- ./scripts/: Model fitting scripts (examples below).
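For background on why the Poisson reparameterization in ./base/distributions.py is needed at all: standard Poisson sampling produces discrete counts, so the sampling step is non-differentiable with respect to the rate parameters and cannot be trained through directly. A minimal numpy illustration of that obstacle (this is not the repository's algorithm, just context):

```python
import numpy as np

rng = np.random.default_rng(0)

# Per-latent firing rates (lambda) -- illustrative values only.
rates = np.array([0.5, 2.0, 8.0])

# Standard Poisson sampling: integer spike counts drawn from the rates.
counts = rng.poisson(lam=rates, size=(1000, 3))

# The samples are non-negative integers; the draw is discrete, so no
# gradient flows back to `rates`. This is the problem the P-VAE's
# reparameterization algorithm is designed to work around.
assert counts.min() >= 0
assert np.issubdtype(counts.dtype, np.integer)

# The empirical mean still tracks lambda, as expected for Poisson.
empirical_means = counts.mean(axis=0)
```
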
We also provide a minimal PyTorch Lightning implementation of the P-VAE, stripped down to its essential components. This serves as an excellent starting point for understanding the model. Check it out:
To train a model, run:

```bash
cd scripts/
./fit_vae.sh <device> <dataset> <model> <archi>
```

- `<device>`: int, CUDA device index.
- `<dataset>`: str, choices = {'vH16', 'CIFAR16', 'MNIST'}.
- `<model>`: str, choices = {'poisson', 'categorical', 'gaussian', 'laplace'}.
- `<archi>`: str, architecture format = {'lin|lin', 'conv+b|lin', 'conv+b|conv+b'}, interpreted as enc|dec.
In the paper, we refer to 'vH16' and 'CIFAR16' options as "van Hateren" and "CIFAR_16x16", respectively. In earlier versions of the code, the van Hateren dataset was also called DOVES. Therefore, vH16, van Hateren, and DOVES are interchangeable.
See ./main/train_vae.py for additional arguments. For example, you can set the latent dimensionality and the KL weight:

```bash
./fit_vae.sh <device> <dataset> <model> <archi> --n_latents 1024 --kl_beta 2.5
```

Analysis code:

- results.ipynb: Generates all VAE-related tables and figures from the paper.
- results_lca.py: Generates sparse coding results.
We provide four linear VAE model checkpoints trained with:
```bash
./fit_vae.sh 0 'vH16' <model> 'lin|lin'
```

Checkpoints are located in ./checkpoints/ and can be loaded and visualized using load_models.ipynb. If additional model checkpoints would be helpful, feel free to reach out.
Download the processed datasets from the following links:
- Complete folder: Drive Link.
- Or individual datasets:
Place the downloaded data under ~/Datasets/ with the following structure:
- ~/Datasets/DOVES/vH16
- ~/Datasets/CIFAR16/xtract16
- ~/Datasets/MNIST/processed
For details, see the make_dataset() function in ./base/dataset.py.
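Before running make_dataset(), it can be handy to confirm the layout above is in place. The helper below is hypothetical (not part of the repository); it simply reports which of the expected sub-directories are missing under the data root:

```python
from pathlib import Path

# Expected dataset sub-directories, as described in this README.
EXPECTED = [
    "DOVES/vH16",
    "CIFAR16/xtract16",
    "MNIST/processed",
]

def missing_datasets(root: str = "~/Datasets") -> list:
    """Return the expected sub-directories not present under `root`.

    Hypothetical convenience helper -- the repo itself resolves paths
    inside make_dataset() in ./base/dataset.py.
    """
    base = Path(root).expanduser()
    return [rel for rel in EXPECTED if not (base / rel).is_dir()]
```

Calling `missing_datasets()` returns an empty list when everything is in place; any entries it returns still need to be downloaded and placed under ~/Datasets/.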
If you use our code in your research, please cite our paper:
@inproceedings{vafaii2024poisson,
title={Poisson Variational Autoencoder},
author={Hadi Vafaii and Dekel Galor and Jacob L. Yates},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=ektPEcqGLb},
}

- For code-related questions, please open an issue in this repository.
- For paper-related questions, contact me at vafaii@berkeley.edu.

