
# Generative Models: Normalizing Flows & Diffusion


From-scratch implementations of RealNVP Normalizing Flows and Denoising Diffusion Probabilistic Models (DDPM) in JAX, evaluated on standard 2D benchmark distributions.

*(Figure: model comparison)*

## Project Overview

This project implements and compares two powerful generative modeling approaches:

| Model   | Paper                                    | Year | Approach         |
|---------|------------------------------------------|------|------------------|
| RealNVP | Density estimation using Real NVP        | 2016 | Normalizing flow |
| DDPM    | Denoising Diffusion Probabilistic Models | 2020 | Diffusion model  |

## Key Findings

- RealNVP provides exact likelihood computation and fast sampling
- DDPM produces high-quality samples through iterative denoising
- Both models successfully learn complex 2D distributions
- Trade-off: RealNVP is faster to train and sample from, while DDPM can capture more complex distributions given sufficient training

## Supported Datasets

All datasets are standard benchmarks used in generative modeling research:

| Dataset      | Description                   | Difficulty |
|--------------|-------------------------------|------------|
| Moons        | Two interleaving half-circles | Easy       |
| Circles      | Concentric circles            | Easy       |
| 8 Gaussians  | Ring of 8 Gaussian blobs      | Medium     |
| Checkerboard | 2D checkerboard pattern       | Medium     |
| Spirals      | Two interleaving spirals      | Hard       |
| Pinwheel     | Multi-arm pinwheel            | Hard       |
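As an illustration, the Moons benchmark can be generated in a few lines. This is a minimal NumPy sketch of the standard construction (two shifted half-circles plus Gaussian jitter); the repository's own generators in `src/data/datasets.py` may differ in details:

```python
import numpy as np

def make_moons(n=1000, noise=0.05, seed=0):
    """Two interleaving half-circles in 2D with Gaussian jitter."""
    rng = np.random.default_rng(seed)
    theta = rng.uniform(0.0, np.pi, n)
    # Upper half-circle, and a lower half-circle shifted right and up.
    upper = np.stack([np.cos(theta), np.sin(theta)], axis=1)
    lower = np.stack([1.0 - np.cos(theta), 0.5 - np.sin(theta)], axis=1)
    half = n // 2
    x = np.concatenate([upper[:half], lower[half:]])
    return x + noise * rng.normal(size=x.shape)
```

The other benchmarks follow the same pattern: a deterministic 2D shape plus a small amount of noise.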

*(Figure: all datasets)*

## Quick Start

### Installation

```sh
git clone https://github.com/yourusername/generative-models-jax.git
cd generative-models-jax
pip install -r requirements.txt
```

### Train Models

```sh
# Train RealNVP on the moons dataset
python scripts/train.py --model realnvp --dataset moons

# Train DDPM on 8gaussians
python scripts/train.py --model ddpm --dataset 8gaussians

# Train both models on all datasets
python scripts/train.py --model both --dataset all
```

### Generate Showcase Figures

```sh
python scripts/generate_figures.py
```

## Architecture Details

### RealNVP (Normalizing Flow)

```
Data x ──► Coupling Layer 1 ──► Coupling Layer 2 ──► ... ──► Latent z ~ N(0,I)
              │                      │
              └── MLP (s,t) ◄────────┘ (alternating splits)
```

Key components:

- **Affine coupling layers:** split the input and transform one half conditioned on the other
- **Alternating pattern:** flip which half is transformed at each layer
- **Exact likelihood:** the change-of-variables formula gives exact log p(x)
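The pieces above fit together in a few lines of JAX. This is a minimal sketch of one affine coupling layer for 2D inputs; the tiny `(W1, b1, W2, b2)` MLP and its parameter layout here are illustrative assumptions, not the repository's exact code:

```python
import jax
import jax.numpy as jnp

def coupling_forward(x, params, flip):
    """One affine coupling layer for 2D data: z_b = x_b * exp(s) + t.

    `params` = (W1, b1, W2, b2), a tiny MLP mapping x_a to (s, t).
    """
    xa, xb = x[..., :1], x[..., 1:]
    if flip:                                  # alternating split pattern
        xa, xb = xb, xa
    W1, b1, W2, b2 = params
    h = jnp.tanh(xa @ W1 + b1)
    st = h @ W2 + b2                          # shape (..., 2): log-scale and shift
    s = jnp.tanh(st[..., :1])                 # tanh-bounded scale for stability
    t = st[..., 1:]
    za = xa                                   # identity on the conditioning half
    zb = xb * jnp.exp(s) + t
    if flip:
        za, zb = zb, za
    z = jnp.concatenate([za, zb], axis=-1)
    log_det = jnp.sum(s, axis=-1)             # triangular Jacobian: log-det = sum of s
    return z, log_det
```

Because `x_a` passes through unchanged, the layer is trivially invertible: `x_b = (z_b - t) * exp(-s)` with `s, t` recomputed from `z_a`.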

Hyperparameters:

- 4 coupling layers
- 2-layer MLPs with 64 hidden units
- Tanh-bounded scale for stability

### DDPM (Diffusion Model)

```
Data x_0 ──► Add noise ──► x_1 ──► ... ──► x_L ~ N(0,I)
                                              │
                                              ▼ (reverse)
Generated ◄── Denoise ◄── ... ◄── Neural net predicts mean
```

Key components:

- **Forward process:** gradually add Gaussian noise over L steps
- **Reverse process:** a neural network predicts each denoising step
- **Linear schedule:** β_1 = 1e-4 to β_L = 0.02
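The schedule and forward process can be sketched in JAX. A minimal version, assuming L = 500 steps (one choice within the range used here); composing the per-step Gaussians gives the closed-form marginal x_t = √ᾱ_t x_0 + √(1-ᾱ_t) ε:

```python
import jax
import jax.numpy as jnp

L = 500                                       # number of diffusion steps (illustrative)
betas = jnp.linspace(1e-4, 0.02, L)           # linear schedule β_1 ... β_L
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)              # ᾱ_t = ∏_{s ≤ t} α_s

def q_sample(x0, t, key):
    """Sample x_t ~ q(x_t | x_0) in closed form (t is a 0-based step index)."""
    eps = jax.random.normal(key, x0.shape)
    ab = alpha_bars[t]
    return jnp.sqrt(ab) * x0 + jnp.sqrt(1.0 - ab) * eps
```

By the final step ᾱ_L is close to zero, so x_L is approximately standard Gaussian regardless of x_0.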

Hyperparameters:

- 500-1000 diffusion steps
- 3-layer MLP with 128-256 hidden units
- Sinusoidal time embeddings (64-dim)
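The sinusoidal time embedding mentioned above can be sketched as follows. This is a minimal transformer-style version; the frequency base of 10000 is an assumption, not necessarily the repository's choice:

```python
import jax.numpy as jnp

def time_embedding(t, dim=64):
    """Sinusoidal embedding of the diffusion step t into a `dim`-vector."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/10000.
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)
    angles = t * freqs
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)])
```

The embedding is concatenated with (or added to) the network input so the MLP can condition on the step t.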

## Project Structure

```
generative-models-jax/
├── src/
│   ├── models/
│   │   ├── realnvp.py      # RealNVP implementation
│   │   └── ddpm.py         # DDPM implementation
│   ├── data/
│   │   └── datasets.py     # 2D dataset generators
│   ├── training/
│   │   └── trainer.py      # Training loops
│   └── utils/
│       └── visualization.py # Plotting utilities
├── scripts/
│   ├── train.py            # Main training script
│   └── generate_figures.py # Generate showcase figures
├── results/
│   └── figures/            # Generated plots
└── requirements.txt
```

## Results

### RealNVP Results

The flow successfully learns to map complex distributions to a standard Gaussian:

*(Figures: RealNVP on Moons, RealNVP on 8 Gaussians, RealNVP on Checkerboard)*

### DDPM Results

The diffusion model learns to denoise samples from pure noise back to the data distribution:

*(Figure: DDPM on Moons)*

### Training Curves

*(Figure: training losses)*

## Implementation Details

### Change of Variables (RealNVP)

For a normalizing flow f: x → z, the log-likelihood is:

```
log p(x) = log p(z) + log |det ∂f/∂x|
```

The affine coupling layer keeps the Jacobian triangular, so the determinant reduces to a sum:

```python
# Affine transformation: z_b = x_b * exp(s) + t, where s, t = MLP(x_a)
log_det = jnp.sum(s, axis=-1)  # triangular Jacobian: log-det is just the sum of the log-scales s
```

### Diffusion Process (DDPM)

Forward process (fixed):

```
q(x_t | x_{t-1}) = N(x_t; √(1-β_t) x_{t-1}, β_t I)
```

Reverse process (learned):

```
p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t² I)
```

The network learns to predict the posterior mean μ_θ, enabling iterative denoising from pure noise.
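The reverse process above amounts to a simple ancestral-sampling loop. A minimal JAX sketch, assuming the common fixed variance choice σ_t² = β_t and a hypothetical `mean_fn` standing in for the trained network:

```python
import jax
import jax.numpy as jnp

def sample(mean_fn, shape, betas, key):
    """Ancestral sampling: start at x_L ~ N(0, I), then repeatedly draw
    x_{t-1} ~ N(μ_θ(x_t, t), σ_t² I) with σ_t² = β_t."""
    L = betas.shape[0]
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, shape)              # x_L: pure noise
    for t in range(L, 0, -1):
        key, sub = jax.random.split(key)
        mu = mean_fn(x, t)                         # learned posterior mean μ_θ(x_t, t)
        noise = jax.random.normal(sub, shape) if t > 1 else 0.0
        x = mu + jnp.sqrt(betas[t - 1]) * noise    # no noise added on the final step
    return x
```

Note the final step is deterministic (no added noise), which is the standard convention.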

## Experiments to Try

```sh
# Compare the models on every dataset
for ds in moons circles 8gaussians checkerboard spirals pinwheel; do
    python scripts/train.py --model realnvp --dataset $ds
done
```

For stronger DDPM results, increase the training epochs in scripts/generate_figures.py (e.g. n_epochs=500 → n_epochs=1000).

## References

- Dinh, L., Sohl-Dickstein, J., and Bengio, S. "Density estimation using Real NVP." 2016.
- Ho, J., Jain, A., and Abbeel, P. "Denoising Diffusion Probabilistic Models." 2020.

## License

MIT License - feel free to use for learning and research.

## Acknowledgments

- Dataset generators inspired by scikit-learn and the original flow papers
- DDPM implementation follows the original paper's formulation
- Built with JAX for automatic differentiation and JIT compilation
