From-scratch implementations of RealNVP Normalizing Flows and Denoising Diffusion Probabilistic Models (DDPM) in JAX, evaluated on standard 2D benchmark distributions.
This project implements and compares two powerful generative modeling approaches:
| Model | Paper | Year | Approach |
|---|---|---|---|
| RealNVP | Density estimation using Real NVP | 2016 | Normalizing Flow |
| DDPM | Denoising Diffusion Probabilistic Models | 2020 | Diffusion Model |
- RealNVP provides exact likelihood computation and fast sampling
- DDPM produces high quality samples through iterative denoising
- Both models successfully learn complex 2D distributions
- Trade-off: RealNVP is faster to train and sample from, while DDPM can capture more complex distributions given sufficient training
All datasets are standard benchmarks used in generative modeling research:
| Dataset | Description | Difficulty |
|---|---|---|
| Moons | Two interleaving half-circles | Easy |
| Circles | Concentric circles | Easy |
| 8 Gaussians | Ring of 8 Gaussian blobs | Medium |
| Checkerboard | 2D checkerboard pattern | Medium |
| Spirals | Two interleaving spirals | Hard |
| Pinwheel | Multi-arm pinwheel | Hard |
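The generators in `src/data/datasets.py` are not reproduced here; as an illustration, the 8 Gaussians benchmark is typically sampled like this (a NumPy sketch; the function name `sample_8gaussians` and its parameters are hypothetical, not the project's actual API):

```python
import numpy as np

def sample_8gaussians(n, scale=2.0, std=0.2, seed=0):
    """Sample n points from a ring of 8 Gaussian blobs (a common 2D benchmark)."""
    rng = np.random.default_rng(seed)
    angles = 2 * np.pi * rng.integers(0, 8, size=n) / 8      # pick one of 8 centers
    centers = scale * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    return centers + std * rng.normal(size=(n, 2))           # add per-blob noise

points = sample_8gaussians(1000)   # shape (1000, 2)
```

The other benchmarks follow the same pattern: a deterministic 2D shape plus Gaussian noise, which makes them cheap to generate on the fly during training.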
```bash
git clone https://github.com/yourusername/generative-models-jax.git
cd generative-models-jax
pip install -r requirements.txt
```

```bash
# Train RealNVP on moons dataset
python scripts/train.py --model realnvp --dataset moons

# Train DDPM on 8gaussians
python scripts/train.py --model ddpm --dataset 8gaussians

# Train both models on all datasets
python scripts/train.py --model both --dataset all
```

```bash
python scripts/generate_figures.py
```

```
Data x ──► Coupling Layer 1 ──► Coupling Layer 2 ──► ... ──► Latent z ~ N(0,I)
                 │                      │
                 └── MLP (s,t) ◄────────┘   (alternating splits)
```
Key components:
- Affine coupling layers: Split input, transform one half conditioned on the other
- Alternating pattern: Flip which half is transformed each layer
- Exact likelihood: Change of variables formula gives exact log p(x)
Hyperparameters:
- 4 coupling layers
- 2-layer MLPs with 64 hidden units
- Tanh-bounded scale for stability
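As an illustration of these pieces, here is a minimal JAX sketch of one affine coupling layer. The function names and parameter layout are hypothetical (not the actual `src/models/realnvp.py` API); the tanh-bounded scale matches the stability note above:

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    """Small MLP emitting scale s and shift t, conditioned on one half of the input."""
    h = jnp.tanh(x @ params["W1"] + params["b1"])
    s, t = jnp.split(h @ params["W2"] + params["b2"], 2, axis=-1)
    return jnp.tanh(s), t                      # tanh-bounded scale for stability

def coupling_forward(params, x):
    """Affine coupling: z_a = x_a, z_b = x_b * exp(s) + t with (s, t) = MLP(x_a)."""
    xa, xb = jnp.split(x, 2, axis=-1)
    s, t = mlp(params, xa)
    zb = xb * jnp.exp(s) + t
    log_det = jnp.sum(s, axis=-1)              # triangular Jacobian: log|det| = Σ s
    return jnp.concatenate([xa, zb], axis=-1), log_det

# Toy example: a 2D point split into two 1D halves
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"W1": 0.1 * jax.random.normal(k1, (1, 64)), "b1": jnp.zeros(64),
          "W2": 0.1 * jax.random.normal(k2, (64, 2)), "b2": jnp.zeros(2)}
x = jnp.array([[0.5, -1.2]])
z, log_det = coupling_forward(params, x)
```

Alternating which half is transformed (swapping `xa` and `xb` every other layer) gives every dimension a chance to be updated, and the inverse is just `x_b = (z_b - t) * exp(-s)`, so sampling is as cheap as a forward pass.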
```
Data x_0 ──► Add noise ──► x_1 ──► ... ──► x_L ~ N(0,I)
                                            │
                                            ▼ (reverse)
Generated ◄── Denoise ◄── ... ◄── Neural net predicts mean
```
Key components:
- Forward process: Gradually add Gaussian noise over L steps
- Reverse process: A neural network predicts the mean of each denoising step
- Linear schedule: β₁ = 1e-4 to β_L = 0.02
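The forward process is fixed once the schedule is chosen, and a standard DDPM identity lets us jump to any step in closed form: q(x_t | x_0) = N(√ᾱ_t x_0, (1−ᾱ_t) I) with ᾱ_t = ∏ₛ αₛ. A minimal JAX sketch (the names `betas`, `alpha_bars`, and `q_sample` are illustrative):

```python
import jax
import jax.numpy as jnp

L = 500                                    # number of diffusion steps
betas = jnp.linspace(1e-4, 0.02, L)        # linear schedule β_1 ... β_L
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)           # ᾱ_t = ∏ α_s

def q_sample(x0, t, key):
    """Sample x_t ~ q(x_t | x_0) in closed form: √ᾱ_t x_0 + √(1-ᾱ_t) ε."""
    eps = jax.random.normal(key, x0.shape)
    return jnp.sqrt(alpha_bars[t]) * x0 + jnp.sqrt(1 - alpha_bars[t]) * eps
```

Because ᾱ_L is tiny for this schedule, x_L is essentially pure noise, which is what makes starting the reverse process from N(0, I) valid.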
Hyperparameters:
- 500-1000 diffusion steps
- 3-layer MLP with 128-256 hidden units
- Sinusoidal time embeddings (64-dim)
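The sinusoidal time embedding mentioned above can be written as follows (a sketch of the Transformer-style formulation; the function name is illustrative, and `dim=64` matches the setting listed):

```python
import jax.numpy as jnp

def time_embedding(t, dim=64):
    """Sinusoidal embedding of a batch of timesteps, Transformer-style."""
    t = jnp.atleast_1d(jnp.asarray(t, dtype=jnp.float32))
    half = dim // 2
    freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)  # geometric frequencies
    args = t[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)  # (batch, dim)
```

This gives the denoising MLP a smooth, unique code for each step t, so one network can be shared across all diffusion steps.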
```
generative-models-jax/
├── src/
│   ├── models/
│   │   ├── realnvp.py          # RealNVP implementation
│   │   └── ddpm.py             # DDPM implementation
│   ├── data/
│   │   └── datasets.py         # 2D dataset generators
│   ├── training/
│   │   └── trainer.py          # Training loops
│   └── utils/
│       └── visualization.py    # Plotting utilities
├── scripts/
│   ├── train.py                # Main training script
│   └── generate_figures.py     # Generate showcase figures
├── results/
│   └── figures/                # Generated plots
└── requirements.txt
```
The flow successfully learns to map complex distributions to a standard Gaussian:
The diffusion model learns to denoise samples from pure noise back to the data distribution:
For a normalizing flow f: x → z, the log-likelihood is:
log p(x) = log p(z) + log |det ∂f/∂x|
The affine coupling layer ensures the Jacobian is triangular, making the determinant easy to compute:
```python
# Affine transformation: z_b = x_b * exp(s) + t
# where s, t = MLP(x_a)
log_det = sum(s)  # simple sum, since the Jacobian is triangular
```

Forward process (fixed):
q(x_t | x_{t-1}) = N(x_t; √(1-β_t) x_{t-1}, β_t I)
Reverse process (learned):
p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t² I)
The network learns to predict the posterior mean μ_θ, enabling iterative denoising from pure noise.
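The resulting ancestral sampling loop can be sketched as below. `predict_mean` is a stand-in for the trained MLP (not the project's actual API), and σ_t² = β_t is one of the choices from the DDPM paper:

```python
import jax
import jax.numpy as jnp

def sample(predict_mean, betas, key, n_samples=256, dim=2):
    """Iterative denoising: start from pure noise, apply p_θ(x_{t-1} | x_t) step by step."""
    L = len(betas)
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (n_samples, dim))       # x_L ~ N(0, I)
    for t in range(L - 1, -1, -1):
        mean = predict_mean(x, t)                      # μ_θ(x_t, t) from the network
        key, sub = jax.random.split(key)
        noise = jax.random.normal(sub, x.shape) if t > 0 else 0.0
        x = mean + jnp.sqrt(betas[t]) * noise          # σ_t² = β_t (a common choice)
    return x
```

Note the final step (t = 0) adds no noise, so the returned samples are the predicted means themselves.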
```bash
# Compare on different datasets
for ds in moons circles 8gaussians checkerboard spirals pinwheel; do
    python scripts/train.py --model realnvp --dataset $ds
done
```

```bash
# Increase training epochs for better DDPM results
# Edit scripts/generate_figures.py: n_epochs=500 → n_epochs=1000
```

- Density estimation using Real NVP - Dinh et al., 2016
- Denoising Diffusion Probabilistic Models - Ho et al., 2020
- Tutorial on Normalizing Flows - Kobyzev et al., 2020
- Understanding Diffusion Models - Luo, 2022
MIT License - feel free to use for learning and research.
- Dataset generators inspired by scikit-learn and the original flow papers
- DDPM implementation follows the original paper's formulation
- Built with JAX for automatic differentiation and JIT compilation






