jax_generative_models is a minimal JAX implementation that unifies Diffusion and Flow Matching under a shared interface. By abstracting both methods behind a common "strategy" interface, the project highlights their structural similarities and differences. With Tyro for configuration and Rerun for visualization, it serves as a compact yet extensible base for experimenting with generative models.
| | cat | moon | swiss-roll | mnist |
|---|---|---|---|---|
| Diffusion (DDPM) | ![]() | ![]() | ![]() | ![]() |
| Flow Matching | ![]() | ![]() | ![]() | ![]() |
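The "strategy" abstraction described above can be sketched roughly as follows. This is an illustrative sketch, not the project's actual API: the names `GenerativeStrategy`, `loss`, `sample`, and `train_step` are assumptions made for this example.

```python
# Sketch of a shared "strategy" interface; all names here are
# illustrative assumptions, not the project's actual API.
from typing import Protocol

import jax
import jax.numpy as jnp


class GenerativeStrategy(Protocol):
    """Hypothetical common interface for DDPM and Flow Matching."""

    def loss(self, model_fn, params, x0, key):
        """Per-batch training loss for this strategy."""
        ...

    def sample(self, model_fn, params, shape, key):
        """Transport noise samples toward the data distribution."""
        ...


def train_step(strategy: GenerativeStrategy, model_fn, params, x0, key, lr=1e-3):
    # The training loop is strategy-agnostic: swapping DDPM for
    # Flow Matching only changes what .loss() computes, not this code.
    grads = jax.grad(lambda p: strategy.loss(model_fn, p, x0, key))(params)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Because the trainer only touches the strategy through `loss`, the two algorithms can share one training loop and be compared side by side.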
- Prerequisites
  - uv: A simple and fast Python package manager. Refer to the official documentation for one-command installation.
  - make: Used to run shortcuts such as `make setup`. It is optional, so you can also run the commands in the `Makefile` manually.
  - CUDA 12 (or a compatible version) and an NVIDIA GPU: Required for GPU acceleration with JAX.
- Clone the repository

  ```shell
  cd <path-to-your-workspace>
  git clone https://github.com/MizuhoAOKI/jax_generative_models.git
  ```
- Set up the virtual environment and install dependencies

  ```shell
  cd jax_generative_models
  ```

  - For CPU only:

    ```shell
    make setup_cpu
    ```

  - For GPU with CUDA 12 support:

    ```shell
    make setup_gpu_cuda12
    ```

  If you are using a GPU, set up CUDA-related environment variables:

  ```shell
  cd jax_generative_models
  source setup_gpu.sh
  ```

  Run this in every new shell before executing training/generation scripts.
- Train a generative model

  ```shell
  uv run scripts/main.py train strategy:<STRATEGY_NAME> model:<MODEL_NAME> dataset:<DATASET_NAME>
  ```
- Generate samples from a trained model

  ```shell
  uv run scripts/main.py generate strategy:<STRATEGY_NAME> model:<MODEL_NAME> dataset:<DATASET_NAME>
  ```
- Make an animation of the transport process after training

  ```shell
  uv run scripts/main.py animate strategy:<STRATEGY_NAME> model:<MODEL_NAME> dataset:<DATASET_NAME>
  ```
- Visualize training progress

  Rerun is a visualization tool that lets you monitor training progress in real time or inspect logged results afterward. Run Rerun from another terminal:

  ```shell
  make rerun
  ```

  (Demo video: rerun.mp4)
You can replace the placeholders in the commands above with the following options. If omitted, each argument falls back to its default value.
| Placeholder | Options | Default | Description |
|---|---|---|---|
| `<STRATEGY_NAME>` | `ddpm`, `flow-matching` | `ddpm` | Generative modeling strategy. |
| `<MODEL_NAME>` | `mlp`, `resnet`, `unet` | `mlp` | Model architecture to use. |
| `<DATASET_NAME>` | `cat`, `gaussian-mixture`, `moon`, `swiss-roll`, `mnist` | `cat` | Target dataset for training/generation. |
For example, to train a DDPM with an MLP on the "cat" dataset (the default configuration), run:

```shell
uv run scripts/main.py train strategy:ddpm model:mlp dataset:cat
```

The `unet` model is specifically designed for image data such as `mnist`, so it should be used with image-based datasets.
The following command shows a recommended set of parameters for training `unet` on `mnist` using the `flow-matching` strategy:

```shell
uv run scripts/main.py train --batch-size 128 --vis.num-vis-samples 256 strategy:flow-matching --strategy.num-transport-steps 100 model:unet dataset:mnist
```

You can then generate samples conditioned on a specific digit. For example, the following command generates 100 images of the digit "5":

```shell
uv run scripts/main.py generate --condition 5 --num-samples 100 strategy:flow-matching model:unet dataset:mnist
```

If you omit the `--condition` option, the model will generate random digits from 0 to 9.
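The `--strategy.num-transport-steps` option controls how finely the learned velocity field is integrated at generation time. A minimal sketch of that transport loop, assuming forward Euler integration from noise at t=0 to data at t=1 (the function name and signature are hypothetical, not the project's API):

```python
import jax
import jax.numpy as jnp


def transport(model_fn, params, shape, key, num_transport_steps=100):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with
    forward Euler. More steps give a finer integration at higher cost."""
    x = jax.random.normal(key, shape)  # start from Gaussian noise
    dt = 1.0 / num_transport_steps

    def step(x, i):
        t = jnp.full((shape[0], 1), i * dt)
        return x + dt * model_fn(params, x, t), None

    x, _ = jax.lax.scan(step, x, jnp.arange(num_transport_steps))
    return x
```

Using `jax.lax.scan` rather than a Python loop keeps the integration loop inside a single traced computation, so it compiles once regardless of the step count.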
- Ho et al. Denoising Diffusion Probabilistic Models. 2020.
- Lipman et al. Flow Matching for Generative Modeling. 2023.
- Holderrieth et al. An Introduction to Flow Matching and Diffusion Models. 2025.
- Lipman et al. Flow Matching Guide and Code. 2024.







