Installation · Quickstart · Equations · Features · Documentation · Background · Citation
Exponax efficiently solves partial differential equations in 1D, 2D, and 3D on periodic
domains using Fourier spectral methods and exponential time differencing. It ships more
than 46 PDE solvers covering linear, nonlinear, and reaction-diffusion dynamics. Built
entirely on JAX and Equinox, every solver is automatically differentiable,
JIT-compilable, and GPU/TPU-ready, making it ideal for physics-based deep learning
workflows.
```bash
pip install exponax
```

Requires Python 3.10+ and JAX 0.4.13+. See the JAX install guide.
Simulate the chaotic Kuramoto-Sivashinsky equation in 1D with a single stepper object; one line rolls out 500 time steps:
```python
import jax
import exponax as ex
import matplotlib.pyplot as plt
ks_stepper = ex.stepper.KuramotoSivashinskyConservative(
num_spatial_dims=1, domain_extent=100.0,
num_points=200, dt=0.1,
)
u_0 = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims=1, cutoff=5
)(num_points=200, key=jax.random.PRNGKey(0))
trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)
plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower")
plt.xlabel("Time"); plt.ylabel("Space"); plt.show()
```

Because every stepper is a differentiable JAX function, you can freely compose it with `jax.grad`, `jax.vmap`, and `jax.jit`:
```python
# Jacobian of the stepper function, evaluated at the initial condition
jacobian = jax.jacfwd(ks_stepper)(u_0)
```

For a next step, check out the tutorial on 1D Advection, which explains the basics of Exponax.
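The same composability extends through whole trajectories. Below is a minimal sketch, with an invented scalar loss purely for illustration, that differentiates a 10-step rollout of the quickstart stepper with respect to the initial state:

```python
import jax
import jax.numpy as jnp
import exponax as ex

# Invented scalar loss for illustration: mean squared magnitude of the
# state after 10 autoregressive steps of the quickstart stepper.
def loss_fn(u_0):
    trj = ex.rollout(ks_stepper, 10, include_init=False)(u_0)  # (10, 1, 200)
    return jnp.mean(trj[-1] ** 2)

# Gradient of the loss w.r.t. the initial condition, JIT-compiled.
du_0 = jax.jit(jax.grad(loss_fn))(u_0)  # same shape as u_0
```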
Linear PDEs

| Equation | Stepper | Dimensions |
|---|---|---|
| Advection | `Advection` | 1D, 2D, 3D |
| Diffusion | `Diffusion` | 1D, 2D, 3D |
| Advection-Diffusion | `AdvectionDiffusion` | 1D, 2D, 3D |
| Dispersion | `Dispersion` | 1D, 2D, 3D |
| Hyper-Diffusion | `HyperDiffusion` | 1D, 2D, 3D |
Nonlinear PDEs

| Equation | Stepper | Dimensions |
|---|---|---|
| Burgers | `Burgers` | 1D, 2D, 3D |
| Korteweg-de Vries | `KortewegDeVries` | 1D, 2D, 3D |
| Kuramoto-Sivashinsky | `KuramotoSivashinsky` | 1D, 2D, 3D |
| KS (conservative) | `KuramotoSivashinskyConservative` | 1D, 2D, 3D |
| Navier-Stokes (vorticity) | `NavierStokesVorticity` | 2D |
| Kolmogorov Flow (vorticity) | `KolmogorovFlowVorticity` | 2D |
Reaction-diffusion PDEs

| Equation | Stepper | Dimensions |
|---|---|---|
| Fisher-KPP | `reaction.FisherKPP` | 1D, 2D, 3D |
| Allen-Cahn | `reaction.AllenCahn` | 1D, 2D, 3D |
| Cahn-Hilliard | `reaction.CahnHilliard` | 1D, 2D, 3D |
| Gray-Scott | `reaction.GrayScott` | 1D, 2D, 3D |
| Swift-Hohenberg | `reaction.SwiftHohenberg` | 1D, 2D, 3D |
Generic stepper families (for advanced / custom dynamics)
These parametric families generalize the concrete steppers above. Each comes in three flavors: physical coefficients, normalized, and difficulty-based.
| Family | Nonlinearity | Generalizes |
|---|---|---|
| `GeneralLinearStepper` | None | Advection, Diffusion, Dispersion, etc. |
| `GeneralConvectionStepper` | Quadratic convection | Burgers, KdV, KS (conservative) |
| `GeneralGradientNormStepper` | Gradient norm | Kuramoto-Sivashinsky |
| `GeneralVorticityConvectionStepper` | Vorticity convection (2D only) | Navier-Stokes, Kolmogorov Flow |
| `GeneralPolynomialStepper` | Arbitrary polynomial | Fisher-KPP, Allen-Cahn, etc. |
| `GeneralNonlinearStepper` | Convection + gradient norm + polynomial | Most of the above |
See the normalized & difficulty interface docs for details.
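As a rough illustration of the physical-coefficients flavor, an advection-diffusion stepper could be assembled from the linear family along these lines. The keyword for the coefficient tuple (here `coefficients`) is an assumption and may differ between Exponax versions; consult the docs linked above:

```python
import exponax as ex

# Sketch only: the `coefficients` keyword is assumed; entry i would scale
# the i-th spatial derivative, i.e. u_t = a_0 u + a_1 u_x + a_2 u_xx + ...
adv_diff = ex.stepper.GeneralLinearStepper(
    num_spatial_dims=1, domain_extent=100.0,
    num_points=200, dt=0.1,
    coefficients=(0.0, -1.0, 0.1),  # advection speed 1.0, diffusivity 0.1
)
```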
- Hardware-agnostic: runs on CPU, GPU, or TPU in single or double precision.
- Fully differentiable: compute gradients of solutions w.r.t. initial conditions, PDE parameters, or the weights of neural networks composed with the solvers, via `jax.grad`.
- Vectorized batching: advance multiple states or sweep over parameter grids in parallel using `jax.vmap` (and `eqx.filter_vmap`); see the sketch after this list.
- Deep-learning native: every stepper is an Equinox `Module`, composable with neural networks out of the box.
- Lightweight design: no custom grid or state objects; everything is plain `jax.numpy` arrays and callable PyTrees.
- Initial conditions: a library of random IC distributions (truncated Fourier series, Gaussian random fields, etc.).
- Utilities: spectral derivatives, grid creation, autoregressive rollout, interpolation, and more.
- Extensible: add new PDEs by subclassing `BaseStepper`.
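As referenced in the batching bullet above, a typical pattern splits the PRNG key to draw several initial conditions and then vmaps the stepper over the batch. A minimal sketch reusing the quickstart objects:

```python
import jax
import exponax as ex

# Draw 16 initial conditions by splitting the PRNG key ...
ic_gen = ex.ic.RandomTruncatedFourierSeries(num_spatial_dims=1, cutoff=5)
keys = jax.random.split(jax.random.PRNGKey(42), 16)
u_0_batch = jax.vmap(lambda k: ic_gen(num_points=200, key=k))(keys)  # (16, 1, 200)

# ... then advance all of them in parallel with one vmapped stepper call.
u_1_batch = jax.vmap(ks_stepper)(u_0_batch)  # (16, 1, 200)
```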
Documentation is available at fkoehler.site/exponax. Key pages:
- 1D Advection Tutorial: learn the basics
- Solver Showcase 1D / 2D / 3D: a visual gallery of all dynamics
- Creating Your Own Solvers: extend Exponax with custom PDEs
- Training a Neural Operator: use Exponax for synthetic data generation and training of a neural emulator
- Stepper Overview: API reference for all steppers
- Performance Hints: tips for fast simulations
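One generic JAX pattern worth knowing independently of the linked hints: JIT-compile the full rollout once and, when timing, block on the result so you do not measure only the asynchronous dispatch. A sketch using the quickstart objects:

```python
import jax
import exponax as ex

# Compile the 500-step rollout once; later calls reuse the executable.
rollout_fn = jax.jit(ex.rollout(ks_stepper, 500))

trj = rollout_fn(u_0)
trj.block_until_ready()  # wait for async execution before timing/printing
```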
Exponax solves semi-linear PDEs of the form

$$ \partial_t u = \mathcal{L} u + \mathcal{N}(u), $$

where $\mathcal{L}$ is a linear differential operator, treated exactly in Fourier space, and $\mathcal{N}$ is a nonlinear operator, handled by the exponential time differencing scheme.
By restricting to periodic domains on scaled hypercubes with uniform Cartesian grids, all transforms reduce to FFTs, yielding blazing-fast simulations. For example, 50 trajectories of the 2D Kuramoto-Sivashinsky equation (200 time steps, 128×128 grid) are generated in under a second on a modern GPU.
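For intuition, the simplest member of this scheme family, first-order exponential time differencing (ETD1, Cox & Matthews, see the references below), integrates the linear part exactly and holds the nonlinearity constant over one step of size $\Delta t$:

$$ u_{n+1} = e^{\Delta t \, \mathcal{L}} \, u_n + \mathcal{L}^{-1} \left( e^{\Delta t \, \mathcal{L}} - I \right) \mathcal{N}(u_n). $$

Since $\mathcal{L}$ is diagonal in Fourier space, the exponential and the inverse reduce to elementwise operations on the spectral coefficients; the higher-order ETDRK schemes referenced below refine this idea.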
References
- Cox, S.M. and Matthews, P.C. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455. doi:10.1006/jcph.2002.6995
- Kassam, A.K. and Trefethen, L.N. "Fourth-order time-stepping for stiff PDEs." SIAM Journal on Scientific Computing 26.4 (2005): 1214-1233. doi:10.1137/S1064827502410633
- Montanelli, H. and Bootland, N. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327. doi:10.1016/j.matcom.2020.06.008
This package is greatly inspired by the
spinX module of the
ChebFun package in MATLAB. spinX served as a
reliable data generator for early works in physics-based deep learning, e.g.,
Deep Hidden Physics Models
and Fourier Neural
Operators.
However, due to the two-language barrier, dynamically calling MATLAB solvers
from Python-based deep learning workflows is difficult, if not impossible. It also
excludes the option to differentiate through them, ruling out
differentiable-physics approaches like solver-in-the-loop correction or
diverted-chain training.
We view Exponax as a spiritual successor of spinX. JAX, as the
computational backend, elevates the power of this solver type with automatic
vectorization (jax.vmap), backend-agnostic execution (CPU/GPU/TPU), and tight
integration for deep learning via its versatile automatic differentiation
engine. With reproducible randomness in JAX, datasets can be re-created in
seconds; there is no need to ever write them to disk.
Beyond ChebFun, other popular pseudo-spectral implementations include Dedalus in the Python world and FourierFlows.jl in the Julia ecosystem (the latter was especially helpful for verifying our implementation of the contour integral method and dealiasing).
Exponax was developed as part of the
APEBench benchmark suite for
autoregressive neural emulators of PDEs. The accompanying paper was accepted at
NeurIPS 2024. If you find this package useful for your research, please
consider citing it:
```bibtex
@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}
```

If you enjoy the project, feel free to give it a star on GitHub!
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler

