
Releases: MizuhoAOKI/jax_generative_models

v0.2.0: Add MNIST Example

19 Dec 08:43
f527916


Release Description

This release adds an MNIST image-generation example to jax_generative_models.

Key Features

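  • Model Architecture: Introduce a U-Net model (unet) for image generation.
  • Datasets: Add the mnist dataset (28×28 grayscale handwritten digits).
  • Configuration: Use Tyro command-line arguments to choose the condition under which the trained model generates samples. For example, the following generates 100 samples with condition 5 using the flow-matching strategy and the unet model on mnist:
    uv run scripts/main.py generate --condition 5 --num-samples 100 strategy:flow-matching model:unet dataset:mnist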
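The `--condition` and `--num-samples` options suggest a class-conditional sampler: start from Gaussian noise and integrate a learned, condition-dependent velocity field from t = 0 to t = 1. Below is a minimal JAX sketch under assumptions, not the repository's implementation. `euler_sample` and `velocity_fn` are hypothetical names, and the convention that t = 0 is noise and t = 1 is data is also an assumption. In practice `velocity_fn` would be the trained unet. Here a hypothetical analytic "oracle" field stands in for it, for a toy dataset where each class is a single point, so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity_fn, key, condition, num_samples, dim, num_steps=100):
    """Hypothetical sampler: integrate dx/dt = v(x, t, c) from t=0 (noise) to t=1 (data)."""
    x = jax.random.normal(key, (num_samples, dim))    # x_0 ~ N(0, I)
    labels = jnp.full((num_samples,), condition)      # every sample shares one condition
    dt = 1.0 / num_steps

    def step(x, i):
        t = i * dt                                    # current time on the grid [0, 1)
        return x + dt * velocity_fn(x, t, labels), None  # explicit Euler update

    x, _ = jax.lax.scan(step, x, jnp.arange(num_steps))
    return x


# Toy check (assumption): class c is a point mass at means[c]. For the linear path
# x_t = (1 - t) x_0 + t * means[c], the exact velocity is (means[c] - x) / (1 - t).
means = jnp.array([[0.0, 0.0], [3.0, -2.0]])


def oracle_velocity(x, t, labels):
    return (means[labels] - x) / (1.0 - t)


samples = euler_sample(oracle_velocity, jax.random.PRNGKey(0), condition=1,
                       num_samples=100, dim=2)
# Every sample lands on means[1] = [3, -2], up to float32 rounding.
```

With this oracle field, each Euler step shrinks the residual `means[c] - x` by exactly (1 - t_next) / (1 - t), so the final step at t = 1 - dt sends every sample onto the class point. A trained network only approximates such a field, which is why real samplers typically use many steps or a higher-order solver.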

Generation Results

[Figure: mnist samples generated with Diffusion (DDPM) and with Flow Matching]

v0.1.0: Initial Release of JAX Generative Models - Unified Interface for Diffusion and Flow Matching

05 Dec 10:21


Release Description

This is the initial release of jax_generative_models 🐱
It is a minimal JAX codebase that treats Diffusion and Flow Matching as interchangeable strategies for transporting a simple noise distribution to the data distribution.
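One way to picture the unified interface: each strategy turns a clean batch into a (model input, time, regression target) triple, so a single training loop serves both. The sketch below illustrates that idea under stated assumptions and is not the repository's API. `Strategy`, `training_pair`, `TrainingPair`, and `loss` are hypothetical names. DDPM uses the common linear beta schedule (1e-4 to 0.02 over 1000 steps) with a noise-prediction target. Flow matching uses the linear path x_t = (1 - t) x_noise + t x_data with velocity target x_data - x_noise. Data is assumed to be flat with shape (batch, dim), as in the 2D toy datasets.

```python
from typing import NamedTuple, Protocol

import jax
import jax.numpy as jnp


class TrainingPair(NamedTuple):
    x_t: jax.Array     # noised / interpolated input fed to the model
    t: jax.Array       # time: integer step (DDPM) or float in [0, 1) (flow matching)
    target: jax.Array  # regression target for the model output


class Strategy(Protocol):  # hypothetical common interface
    def training_pair(self, key: jax.Array, x_data: jax.Array) -> TrainingPair: ...


class DDPM:
    """Noise-prediction objective with a linear beta schedule (assumed defaults)."""

    def __init__(self, num_steps: int = 1000, beta_min: float = 1e-4, beta_max: float = 0.02):
        betas = jnp.linspace(beta_min, beta_max, num_steps)
        self.alpha_bars = jnp.cumprod(1.0 - betas)  # cumulative signal fraction per step
        self.num_steps = num_steps

    def training_pair(self, key, x_data):
        k_t, k_eps = jax.random.split(key)
        t = jax.random.randint(k_t, (x_data.shape[0],), 0, self.num_steps)
        eps = jax.random.normal(k_eps, x_data.shape)
        a = self.alpha_bars[t][:, None]
        x_t = jnp.sqrt(a) * x_data + jnp.sqrt(1.0 - a) * eps  # sample from q(x_t | x_0)
        return TrainingPair(x_t, t, eps)  # model learns to predict the noise


class FlowMatching:
    """Linear-path flow matching: x_t = (1 - t) x_noise + t x_data."""

    def training_pair(self, key, x_data):
        k_t, k_noise = jax.random.split(key)
        t = jax.random.uniform(k_t, (x_data.shape[0],))
        x_noise = jax.random.normal(k_noise, x_data.shape)
        x_t = (1.0 - t[:, None]) * x_noise + t[:, None] * x_data
        return TrainingPair(x_t, t, x_data - x_noise)  # model learns the velocity


def loss(model_fn, strategy: Strategy, key, x_data):
    """Same mean-squared-error loss for either strategy."""
    pair = strategy.training_pair(key, x_data)
    pred = model_fn(pair.x_t, pair.t)
    return jnp.mean((pred - pair.target) ** 2)
```

Because both strategies only differ in how they build `TrainingPair`, swapping `DDPM()` for `FlowMatching()` leaves `loss` and any model untouched. The strategies do differ at sampling time (a reverse Markov chain versus an ODE solve), which this sketch does not cover.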

Key Features

  • Unified Strategy Interface: Seamlessly switch between the ddpm and flow-matching algorithms.
  • Model Architectures: Includes implementations of basic mlp and resnet models.
  • Toy Datasets: Built-in support for 2D datasets (cat, gaussian-mixture, moon, swiss-roll) for rapid experimentation.
  • Visualization: Integrated with Rerun for real-time training monitoring and transport process visualization.
  • Configuration: Type-safe configuration management using Tyro.
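To make the 2D toy datasets concrete, here is a hypothetical sampler in the spirit of gaussian-mixture. The function name `sample_gaussian_mixture` is illustrative. The number of modes, ring radius, and spread are also assumptions, since these notes do not give the repository's settings.

```python
import jax
import jax.numpy as jnp


def sample_gaussian_mixture(key, num_samples, num_modes=8, radius=4.0, std=0.2):
    """Hypothetical toy dataset: isotropic 2D Gaussians spaced evenly on a circle."""
    k_mode, k_noise = jax.random.split(key)
    angles = 2.0 * jnp.pi * jnp.arange(num_modes) / num_modes
    centers = radius * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=-1)  # (num_modes, 2)
    modes = jax.random.randint(k_mode, (num_samples,), 0, num_modes)           # one component per sample
    return centers[modes] + std * jax.random.normal(k_noise, (num_samples, 2))


points = sample_gaussian_mixture(jax.random.PRNGKey(0), 1000)  # shape (1000, 2)
```

Low-dimensional, multi-modal targets like this make it easy to see whether a strategy covers every mode or collapses onto a few, which is the kind of behavior the Rerun transport visualization is meant to expose.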

Technical Stack

  • Core: Python 3.12, JAX, Equinox, Tyro
  • Management: uv (Package Manager), Makefile support

Getting Started

Please refer to the README for installation instructions (CPU/CUDA 12) and usage examples.

[Video: jax_generative_models.mp4]