exponax logo

Efficient Differentiable n-d PDE solvers built on top of JAX & Equinox.

PyPI Tests codecov docs-latest Changelog License

Installation • Quickstart • Equations • Features • Documentation • Background • Citation

Exponax solves partial differential equations in 1D, 2D, and 3D on periodic domains with high efficiency, using Fourier spectral methods and exponential time differencing. It ships more than 46 PDE solvers covering linear, nonlinear, and reaction-diffusion dynamics. Built entirely on JAX and Equinox, every solver is automatically differentiable, JIT-compilable, and GPU/TPU-ready, making it ideal for physics-based deep learning workflows.

Installation

pip install exponax

Requires Python 3.10+ and JAX 0.4.13+. See the JAX install guide.

Quickstart

Simulate the chaotic Kuramoto-Sivashinsky equation in 1D with a single stepper object; one line rolls out 500 time steps:

import jax
import exponax as ex
import matplotlib.pyplot as plt

# Stepper that advances the KS state by dt = 0.1 per call
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
    num_spatial_dims=1, domain_extent=100.0,
    num_points=200, dt=0.1,
)

# Random initial condition: a Fourier series truncated after mode 5
u_0 = ex.ic.RandomTruncatedFourierSeries(
    num_spatial_dims=1, cutoff=5
)(num_points=200, key=jax.random.PRNGKey(0))

# Autoregressive rollout of 500 steps; shape (501, 1, 200) with the initial state included
trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)

# Spacetime plot: time on the horizontal axis, space on the vertical axis
plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower")
plt.xlabel("Time"); plt.ylabel("Space"); plt.show()

Because every stepper is a differentiable JAX function, you can freely compose it with jax.grad, jax.vmap, and jax.jit:

# Jacobian of the stepper function
jacobian = jax.jacfwd(ks_stepper)(u_0)
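
For instance, advancing a whole batch of states in parallel is a one-liner. A minimal sketch reusing ks_stepper and u_0 from above (the batch of 16 identical copies is purely illustrative):

import jax
import jax.numpy as jnp

# Illustrative batch of 16 states, shape (16, 1, 200)
batch_of_states = jnp.stack([u_0] * 16)

# Map the stepper over the leading batch axis and JIT-compile the result
batched_step = jax.jit(jax.vmap(ks_stepper))
next_states = batched_step(batch_of_states)  # shape (16, 1, 200)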

As a next step, check out the tutorial on 1D advection, which explains the basics of Exponax.

Built-in Equations

Linear

| Equation | Stepper | Dimensions |
| --- | --- | --- |
| Advection: $u_t + c \cdot \nabla u = 0$ | Advection | 1D, 2D, 3D |
| Diffusion: $u_t = \nu \Delta u$ | Diffusion | 1D, 2D, 3D |
| Advection-Diffusion: $u_t + c \cdot \nabla u = \nu \Delta u$ | AdvectionDiffusion | 1D, 2D, 3D |
| Dispersion: $u_t = \xi \nabla^3 u$ | Dispersion | 1D, 2D, 3D |
| Hyper-Diffusion: $u_t = -\zeta \Delta^2 u$ | HyperDiffusion | 1D, 2D, 3D |

Nonlinear

| Equation | Stepper | Dimensions |
| --- | --- | --- |
| Burgers: $u_t + \frac{1}{2} \nabla \cdot (u \otimes u) = \nu \Delta u$ | Burgers | 1D, 2D, 3D |
| Korteweg-de Vries: $u_t + \frac{1}{2} \nabla \cdot (u \otimes u) - \nabla^3 u = \mu \Delta u$ | KortewegDeVries | 1D, 2D, 3D |
| Kuramoto-Sivashinsky: $u_t + \frac{1}{2} \vert \nabla u \vert^2 + \Delta u + \Delta^2 u = 0$ | KuramotoSivashinsky | 1D, 2D, 3D |
| KS (conservative): $u_t + \frac{1}{2} \nabla \cdot (u \otimes u) + \Delta u + \Delta^2 u = 0$ | KuramotoSivashinskyConservative | 1D, 2D, 3D |
| Navier-Stokes (vorticity): $\omega_t + (u \cdot \nabla)\omega = \nu \Delta \omega$ | NavierStokesVorticity | 2D |
| Kolmogorov Flow (vorticity): $\omega_t + (u \cdot \nabla)\omega = \nu \Delta \omega + f$ | KolmogorovFlowVorticity | 2D |

Reaction-Diffusion

| Equation | Stepper | Dimensions |
| --- | --- | --- |
| Fisher-KPP: $u_t = \nu \Delta u + r u(1 - u)$ | reaction.FisherKPP | 1D, 2D, 3D |
| Allen-Cahn: $u_t = \nu \Delta u + c_1 u + c_3 u^3$ | reaction.AllenCahn | 1D, 2D, 3D |
| Cahn-Hilliard: $u_t = \nu \Delta(u^3 + c_1 u - \gamma \Delta u)$ | reaction.CahnHilliard | 1D, 2D, 3D |
| Gray-Scott: $u_t = \nu_1 \Delta u + f(1-u) - uv^2, \quad v_t = \nu_2 \Delta v - (f+k)v + uv^2$ | reaction.GrayScott | 1D, 2D, 3D |
| Swift-Hohenberg: $u_t = ru - (k + \Delta)^2 u + g(u)$ | reaction.SwiftHohenberg | 1D, 2D, 3D |

Generic stepper families (for advanced / custom dynamics)

These parametric families generalize the concrete steppers above. Each comes in three flavors: physical coefficients, normalized, and difficulty-based.

| Family | Nonlinearity | Generalizes |
| --- | --- | --- |
| GeneralLinearStepper | None | Advection, Diffusion, Dispersion, etc. |
| GeneralConvectionStepper | Quadratic convection | Burgers, KdV, KS Conservative |
| GeneralGradientNormStepper | Gradient norm | Kuramoto-Sivashinsky |
| GeneralVorticityConvectionStepper | Vorticity convection (2D only) | Navier-Stokes, Kolmogorov Flow |
| GeneralPolynomialStepper | Arbitrary polynomial | Fisher-KPP, Allen-Cahn, etc. |
| GeneralNonlinearStepper | Convection + gradient norm + polynomial | Most of the above |

See the normalized & difficulty interface docs for details.
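
As an illustration, a linear stepper is defined by one coefficient per derivative order. The sketch below is hedged: the module path ex.stepper.generic.GeneralLinearStepper and the linear_coefficients keyword follow the generic stepper docs but may differ between versions, and the coefficient values are arbitrary:

import exponax as ex

# One coefficient per derivative order (identity, first, second derivative),
# i.e., advection with speed 0.1 plus diffusion with diffusivity 0.01
ad_stepper = ex.stepper.generic.GeneralLinearStepper(
    num_spatial_dims=1, domain_extent=1.0,
    num_points=64, dt=0.01,
    linear_coefficients=(0.0, -0.1, 0.01),
)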

Features

  • Hardware-agnostic: run on CPU, GPU, or TPU in single or double precision.
  • Fully differentiable: compute gradients of solutions with respect to initial conditions, PDE parameters, or neural network weights via jax.grad.
  • Vectorized batching: advance multiple states or sweep over parameter grids in parallel using jax.vmap (and eqx.filter_vmap).
  • Deep-learning native: every stepper is an Equinox Module, composable with neural networks out of the box.
  • Lightweight design: no custom grid or state objects; everything is plain jax.numpy arrays and callable PyTrees.
  • Initial conditions: a library of random IC distributions (truncated Fourier series, Gaussian random fields, etc.).
  • Utilities: spectral derivatives, grid creation, autoregressive rollout, interpolation, and more (see the sketch after this list).
  • Extensible: add new PDEs by subclassing BaseStepper.
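
In the same spirit, grids and states are nothing more than arrays. A minimal sketch, assuming the ex.make_grid utility from the documentation (its exact signature is assumed analogous to the steppers'):

import jax
import jax.numpy as jnp
import exponax as ex

# Uniform periodic grid as a plain array of shape (1, 64) in 1D
grid = ex.make_grid(num_spatial_dims=1, domain_extent=2 * jnp.pi, num_points=64)

# An analytical state is just another array of the same shape
u_sine = jnp.sin(grid)

# Random ICs come from callable distribution objects, no special state types
ic_distribution = ex.ic.RandomTruncatedFourierSeries(num_spatial_dims=1, cutoff=3)
u_0 = ic_distribution(num_points=64, key=jax.random.PRNGKey(42))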

Documentation

Documentation is available at fkoehler.site/exponax.

Background

Exponax solves semi-linear PDEs of the form

$$ \partial u / \partial t = Lu + N(u), $$

where $L$ is a linear differential operator and $N$ is a nonlinear differential operator. The linear part is solved exactly via a matrix exponential in Fourier space, while the nonlinear part is integrated with exponential time differencing Runge-Kutta (ETDRK) schemes of order 1 through 4. The complex contour integral method of Kassam & Trefethen is used to evaluate the ETDRK coefficient functions in a numerically stable way.
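
For intuition, the first-order scheme (ETDRK1) of Cox & Matthews advances the Fourier coefficients $\hat{u}$ by

$$ \hat{u}_{n+1} = e^{L \Delta t} \, \hat{u}_n + L^{-1} \left( e^{L \Delta t} - 1 \right) \hat{N}(\hat{u}_n), $$

where $L$ is diagonal in Fourier space, so the operator exponential reduces to element-wise array operations; the higher-order ETDRK schemes refine how the nonlinear term is approximated over the step.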

By restricting to periodic domains on scaled hypercubes with uniform Cartesian grids, all transforms reduce to FFTs, yielding blazing-fast simulations. For example, 50 trajectories of the 2D Kuramoto-Sivashinsky equation (200 time steps, 128x128 grid) are generated in under a second on a modern GPU.
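
If you want to reproduce such a timing, keep in mind that JAX dispatches work asynchronously: block on the result and exclude compilation from the measurement. A minimal sketch reusing ks_stepper and u_0 from the quickstart (the 10-repetition average is arbitrary; timings vary by hardware):

import timeit

import jax
import exponax as ex

# JIT-compile the full 500-step rollout as one function
rollout_fn = jax.jit(ex.rollout(ks_stepper, 500, include_init=True))

# Warm-up call triggers compilation; block so it finishes before timing
_ = rollout_fn(u_0).block_until_ready()

# Time only execution, blocking on the device result in each repetition
seconds = timeit.timeit(lambda: rollout_fn(u_0).block_until_ready(), number=10) / 10
print(f"{seconds * 1e3:.1f} ms per 500-step rollout")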

References
  1. Cox, S.M. and Matthews, P.C. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455. doi:10.1006/jcph.2002.6995
  2. Kassam, A.K. and Trefethen, L.N. "Fourth-order time-stepping for stiff PDEs." SIAM Journal on Scientific Computing 26.4 (2005): 1214-1233. doi:10.1137/S1064827502410633
  3. Montanelli, H. and Bootland, N. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327. doi:10.1016/j.matcom.2020.06.008

Related & Motivation

This package is greatly inspired by the spinX module of the ChebFun package in MATLAB. spinX served as a reliable data generator for early works in physics-based deep learning, e.g., DeepHiddenPhysics and Fourier Neural Operators. However, due to the two-language barrier, dynamically calling MATLAB solvers from Python-based deep learning workflows is difficult or outright impossible. It also rules out differentiating through the solvers, and with that differentiable-physics approaches like solver-in-the-loop correction or diverted-chain training.

We view Exponax as a spiritual successor of spinX. JAX, as the computational backend, elevates the power of this solver type with automatic vectorization (jax.vmap), backend-agnostic execution (CPU/GPU/TPU), and tight integration for deep learning via its versatile automatic differentiation engine. With reproducible randomness in JAX, datasets can be re-created in seconds; there is no need to ever write them to disk.

Beyond ChebFun, other popular pseudo-spectral implementations include Dedalus in the Python world and FourierFlows.jl in the Julia ecosystem (the latter was especially helpful for verifying our implementation of the contour integral method and dealiasing).

Citation

Exponax was developed as part of the APEBench benchmark suite for autoregressive neural emulators of PDEs. The accompanying paper was accepted at NeurIPS 2024. If you find this package useful for your research, please consider citing it:

@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}

If you enjoy the project, feel free to give it a star on GitHub!

Funding

The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.

License

MIT; see the LICENSE file in the repository.


fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler
