diff --git a/diffrax/__init__.py b/diffrax/__init__.py index d8aca8d4..e9b9cc44 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -73,6 +73,7 @@ Ralston, ReversibleHeun, SemiImplicitEuler, + StormerVerlet, Sil3, StratonovichMilstein, Tsit5, diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index ace213c4..dc1a0db1 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -36,5 +36,6 @@ MultiButcherTableau, ) from .semi_implicit_euler import SemiImplicitEuler +from .stormer_verlet import StormerVerlet from .sil3 import Sil3 from .tsit5 import Tsit5 diff --git a/diffrax/solver/stormer_verlet.py b/diffrax/solver/stormer_verlet.py new file mode 100644 index 00000000..52e3f582 --- /dev/null +++ b/diffrax/solver/stormer_verlet.py @@ -0,0 +1,78 @@ +from typing import Tuple + +from equinox.internal import ω + +from ..custom_types import Bool, DenseInfo, PyTree, Scalar +from ..local_interpolation import LocalLinearInterpolation +from ..solution import RESULTS +from ..term import AbstractTerm +from .base import AbstractSolver + +_ErrorEstimate = None +_SolverState = None + +class StormerVerlet(AbstractSolver): + """ Störmer-Verlet method. + + Symplectic method. Does not support adaptive step sizing. Uses 1st order local + linear interpolation for dense/ts output. + """ + + term_structure = (AbstractTerm, AbstractTerm) + interpolation_cls = LocalLinearInterpolation + + def order(self, terms): + return 2 + + def init( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + t1: Scalar, + y0: PyTree, + args: PyTree, + ) -> _SolverState: + return None + + def step( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + t1: Scalar, + y0: Tuple[PyTree, PyTree], + args: PyTree, + solver_state: _SolverState, + made_jump: Bool, + ) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump + + term_1, term_2 = terms + y0_1, y0_2 = y0 + midpoint = (t1 + t0)/2 + + control1_half_1 = term_1.contr(t0, midpoint) + control1_half_2 = term_1.contr(midpoint, t1) + control2 = term_2.contr(t0, t1) + + yhalf_1 = (y0_1 ** ω + term_1.vf_prod(t0, y0_2, args, control1_half_1) ** ω).ω + y1_2 = (y0_2 ** ω + term_2.vf_prod(midpoint, yhalf_1, args, control2) ** ω).ω + y1_1 = (yhalf_1 ** ω + term_1.vf_prod(t1, y1_2, args, control1_half_2 ** ω)).ω + + y1 = (y1_1, y1_2) + dense_info = dict(y0=y0, y1=y1) + return y1, None, dense_info, None, RESULTS.successful + + def func( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + y0: Tuple[PyTree, PyTree], + args: PyTree + ) -> Tuple[PyTree, PyTree]: + term_1, term_2 = terms + y0_1, y0_2 = y0 + f1 = term_1.func(t0, y0_2, args) + f2 = term_2.func(t0, y0_1, args) + return (f1, f2) + + \ No newline at end of file diff --git a/test/helpers.py b/test/helpers.py index b4764ffe..24dfb835 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -7,7 +7,6 @@ import jax.random as jrandom import jax.tree_util as jtu - all_ode_solvers = ( diffrax.Bosh3(), diffrax.Dopri5(), @@ -32,6 +31,11 @@ diffrax.KenCarp5(), ) +all_symplectic_solvers = ( + diffrax.SemiImplicitEuler(), + diffrax.StormerVerlet(), +) + def implicit_tol(solver): if isinstance(solver, diffrax.AbstractImplicitSolver): diff --git a/test/test_integrate.py b/test/test_integrate.py index d2e9cd84..dc09b8e3 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -15,6 +15,7 @@ from .helpers import ( all_ode_solvers, all_split_solvers, + all_symplectic_solvers, implicit_tol, random_pytree, shaped_allclose, @@ -165,6 +166,59 @@ def f(t, y, args): assert -0.9 < order - solver.order(term) < 0.9 +@pytest.mark.parametrize("solver", all_symplectic_solvers) +def test_symplectic_ode_order(solver): + solver = implicit_tol(solver) + key = jrandom.PRNGKey(17) + p_key, q_key, k_key = jrandom.split(key, 3) + p0 = jrandom.uniform(p_key, shape=(), minval=0, maxval=1) + q0 = jrandom.uniform(q_key, shape=(), minval=0, maxval=1) + k = jrandom.uniform(k_key, shape=(), minval=0.1, maxval=10) + y0 = (p0, q0) + t0 = 0 + t1 = 4 + + def p_vector_field(t, q, k): + return q + + def q_vector_field(t, p, k): + return -k * p + + def analytic_solution(t, k, p0, q0): + φ = jnp.sqrt(k) + p_t = p0 * jnp.cos(φ * t) + (q0/φ) * jnp.sin(φ * t) + q_t = -p0 * φ * jnp.sin(φ * t) + q0 * jnp.cos(φ * t) + return p_t, q_t + + + term = ( + diffrax.ODETerm(p_vector_field), + diffrax.ODETerm(q_vector_field), + ) + + true_pT, true_qT = analytic_solution(t1, k, p0, q0) + exponents = [] + errors_p = [] + errors_q = [] + for exponent in [0, -1, -2, -3, -4, -6, -8, -12]: + dt0 = 2**exponent + sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, k, max_steps=None) + pT, qT = sol.ys + error_p = jnp.sum(jnp.abs(pT - true_pT)) + error_q = jnp.sum(jnp.abs(qT - true_qT)) + if error_p < 2**-28 and error_q < 2**-28: + break + exponents.append(exponent) + errors_p.append(jnp.log2(error_q)) + errors_q.append(jnp.log2(error_q)) + + order_p = scipy.stats.linregress(exponents, errors_p). slope + order_q = scipy.stats.linregress(exponents, errors_q). slope + # Same wide range as for general ODE solvers, but we + # require this approximate order both for `p` and `q` + assert -0.9 < order_p - solver.order(term) < 0.9 + assert -0.9 < order_q - solver.order(term) < 0.9 + def _squareplus(x): return 0.5 * (x + jnp.sqrt(x**2 + 4)) @@ -338,14 +392,15 @@ def f(t, y, args): assert shaped_allclose(sol1.derivative(ti), -sol2.derivative(-ti)) -def test_semi_implicit_euler(): +@pytest.mark.parametrize("solver", all_symplectic_solvers) +def test_symplectic_solvers(solver): term1 = diffrax.ODETerm(lambda t, y, args: -y) term2 = diffrax.ODETerm(lambda t, y, args: y) y0 = (1.0, -0.5) dt0 = 0.00001 sol1 = diffrax.diffeqsolve( (term1, term2), - diffrax.SemiImplicitEuler(), + solver, 0, 1, dt0,