Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
QUICSORT as QUICSORT,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
Ros3p as Ros3p,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
Expand Down
5 changes: 5 additions & 0 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Euler,
EulerHeun,
ItoMilstein,
Ros3p,
StratonovichMilstein,
)
from ._step_size_controller import (
Expand Down Expand Up @@ -1034,6 +1035,10 @@ def diffeqsolve(
eqx.is_array_like(xi) and jnp.iscomplexobj(xi)
for xi in jtu.tree_leaves((terms, y0, args))
):
if isinstance(solver, Ros3p):
# TODO: add complex dtype support to ros3p.
raise ValueError("Ros3p does not support complex dtypes.")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check could probably be moved to Ros3p.step? I try to keep the core diffeqsolve code solver-agnostic where possible.


warnings.warn(
"Complex dtype support in Diffrax is a work in progress and may not yet "
"produce correct results. Consider splitting your computation into real "
Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .quicsort import QUICSORT as QUICSORT
from .ralston import Ralston as Ralston
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .ros3p import Ros3p as Ros3p
from .runge_kutta import (
AbstractDIRK as AbstractDIRK,
AbstractERK as AbstractERK,
Expand Down
223 changes: 223 additions & 0 deletions diffrax/_solver/ros3p.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import ClassVar, TypeAlias

import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax as lx
import numpy as np
from jaxtyping import ArrayLike

from .._custom_types import (
Args,
BoolScalarLike,
DenseInfo,
RealScalarLike,
VF,
Y,
)
from .._local_interpolation import ThirdOrderHermitePolynomialInterpolation
from .._solution import RESULTS
from .._term import AbstractTerm
from .base import AbstractAdaptiveSolver


_SolverState: TypeAlias = VF


@dataclass(frozen=True)
class _RosenbrockTableau:
"""The coefficient tableau for Rosenbrock methods"""

m_sol: np.ndarray
m_error: np.ndarray

a_lower: tuple[np.ndarray, ...]
c_lower: tuple[np.ndarray, ...]

α: np.ndarray
γ: np.ndarray

num_stages: int

# Example tableau
#
# α1 | a11 a12 a13 | c11 c12 c13 | γ1
# α1 | a21 a22 a23 | c21 c22 c23 | γ2
# α3 | a31 a32 a33 | c31 c32 c33 | γ3
# ---+----------------
# | m1 m2 m3
# | me1 me2 me3


_tableau = _RosenbrockTableau(
m_sol=np.array([2.0, 0.5773502691896258, 0.4226497308103742]),
m_error=np.array([2.113248654051871, 1.0, 0.4226497308103742]),
a_lower=(
np.array([1.267949192431123]),
np.array([1.267949192431123, 0.0]),
),
c_lower=(
np.array([-1.607695154586736]),
np.array([-3.464101615137755, -1.732050807568877]),
),
α=np.array([0.0, 1.0, 1.0]),
γ=np.array(
[
0.7886751345948129,
-0.2113248654051871,
-1.0773502691896260,
]
),
num_stages=3,
)


class Ros3p(AbstractAdaptiveSolver):
r"""Ros3p method.
3rd order Rosenbrock method for solving stiff equation. Uses third-order Hermite
polynomial interpolation for dense output.
??? cite "Reference"
```bibtex
@article{LangVerwer2001ROS3P,
author = {Lang, J. and Verwer, J.},
title = {ROS3P---An Accurate Third-Order Rosenbrock Solver Designed
for Parabolic Problems},
journal = {BIT Numerical Mathematics},
volume = {41},
number = {4},
pages = {731--738},
year = {2001},
doi = {10.1023/A:1021900219772}
}
```
"""

term_structure: ClassVar = AbstractTerm[ArrayLike, ArrayLike]
interpolation_cls: ClassVar[
Callable[..., ThirdOrderHermitePolynomialInterpolation]
] = ThirdOrderHermitePolynomialInterpolation.from_k

tableau: ClassVar[_RosenbrockTableau] = _tableau

def init(self, terms, t0, t1, y0, args) -> _SolverState:
del t1
return terms.vf(t0, y0, args)

def order(self, terms):
return 3

def step(
self,
terms: AbstractTerm[ArrayLike, ArrayLike],
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional work is required to make it work with MultiTerm. Reading about other rosenbrock method will allow me to design the proper PyTree abstraction. So, I've limited the term structure to the simple ode.

I can implement this now or include it in the next PR along with the next Rosenbrock method.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think it should already work as-is. A MultiTerm is an AbstractTerm already.

Try diffeqsolve(MultiTerm(ODETerm(...), ODETerm(...)), Ros3p(), ...) and see what happens?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, what I can see here though is that the evolving state and the control are being restricted to specifically arrays, and not the general pytree-of-arrays that we aim to support.

Take a look at equinox.internal.ω for a helper that we use ubiquitously to make it easy to work with pytree-valued state.

Alternatively, it would be fairly straightforward to use jax.flatten_util.ravel_pytree before and after the code you already have.

(Could you add a test that uses pytree-valued state to be sure that whichever choice you make works?)

t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Y, DenseInfo, _SolverState, RESULTS]:
y0_leaves = jtu.tree_leaves(y0)
sol_dtype = jnp.result_type(*y0_leaves)

time_derivative = jax.jacfwd(lambda t: terms.vf(t, y0, args))(t0)
control = terms.contr(t0, t1)

γ = jnp.array(self.tableau.γ, dtype=sol_dtype)
α = jnp.array(self.tableau.α, dtype=sol_dtype)

def embed_lower(x):
out = np.zeros(
(self.tableau.num_stages, self.tableau.num_stages), dtype=x[0].dtype
)
for i, val in enumerate(x):
out[i + 1, : i + 1] = val
return jnp.array(out, dtype=sol_dtype)

a_lower = embed_lower(self.tableau.a_lower)
c_lower = embed_lower(self.tableau.c_lower)
m_sol = jnp.array(self.tableau.m_sol, dtype=sol_dtype)
m_error = jnp.array(self.tableau.m_error, dtype=sol_dtype)

# common L.H.S
eye_shape = jax.ShapeDtypeStruct(time_derivative.shape, dtype=sol_dtype)
A = (lx.IdentityLinearOperator(eye_shape) / (control * γ[0])) - (
lx.JacobianLinearOperator(
lambda y, args: terms.vf(t0, y, args), y0, args=args
)
)

u = jnp.zeros(
(self.tableau.num_stages,) + time_derivative.shape, dtype=sol_dtype
)

def use_saved_vf(u):
stage_0_vf = solver_state
stage_0_b = stage_0_vf + ((control * γ[0]) * time_derivative)
stage_0_u = lx.linear_solve(A, stage_0_b).value

u = u.at[0].set(stage_0_u)
start_stage = 1
return u, start_stage

if made_jump is False:
u, start_stage = use_saved_vf(u)
else:
u, start_stage = lax.cond(
eqxi.unvmap_any(made_jump), lambda u: (u, 0), use_saved_vf, u
)

def body(u, stage):
# Σ_j a_{stage j} · u_j
y0_increment = jnp.tensordot(a_lower[stage], u, axes=[[0], [0]])
vf = terms.vf(
t0 + (α[stage] * control),
y0 + y0_increment,
args,
)

# Σ_j (c_{stage j}/control) · u_j
c_scaled_control = jax.vmap(lambda c: c / control)(c_lower[stage])
vf_increment = jnp.tensordot(c_scaled_control, u, axes=[[0], [0]])

b = vf + vf_increment + ((control * γ[stage]) * time_derivative)
# solving Ax=b
stage_u = lx.linear_solve(A, b).value
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a self.linear_solver argument that specifies the linear solver to use? (C.f. where this appears elsewhere in Diffrax, e.g. as an attribute of VeryChord)

u = u.at[stage].set(stage_u)
return u, vf

u, stage_vf = lax.scan(
f=body, init=u, xs=jnp.arange(start_stage, self.tableau.num_stages)
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is some super inside baseball here, but: performing a scan, whose body function both reads from and writes to the same buffer, will be inefficient when reverse-mode differentiated.

This particular combination of things is something that XLA fails to optimize well.

For this reason we have equinox.internal.scan instead, which can be used like so:

u, stage_vf = equinox.internal.scan(
    body, u, jnp.arange(self.tableau.num_stages),
    buffers=lambda x: x,
    kind="checkpointed",
    checkpoints=self.tableau.num_stages
)

The interesting part here is the buffers argument, which is used to specify a path to the particular arrays that will be the subject of this inplace-updating behaviour. The body function will then be called with an array wrapper that does some smart things to avoid the XLA issue.

(You've bumped straight into one of the most difficult JAX issues to tackle, I'm afraid!)


Other than that, note that this will error if start_stage is a traced value, i.e. if made_jump is not False, as then jnp.arange(start_stage, self.tableau.num_stages) will not be an array of known size.

I tihnk the easiest way to tackle this would be to remove the FSAL logic and just evaluate all stages on each step. (It is possible to still do FSAL, it's just very complicated – c.f. AbstractRungeKutta.step – let's not do that now 😁) To be sure we fix this, can you add a test that includes jumps?


y1 = y0 + jnp.tensordot(m_sol, u, axes=[[0], [0]])
y1_lower = y0 + jnp.tensordot(m_error, u, axes=[[0], [0]])
y1_error = y1 - y1_lower

if start_stage == 0:
vf0 = stage_vf[0] # type: ignore
else:
vf0 = solver_state
vf1 = terms.vf(t1, y1, args)
k = jnp.stack((terms.prod(vf0, control), terms.prod(vf1, control)))

dense_info = dict(y0=y0, y1=y1, k=k)
return y1, y1_error, dense_info, vf1, RESULTS.successful

def func(
self,
terms: AbstractTerm[ArrayLike, ArrayLike],
t0: RealScalarLike,
y0: Y,
args: Args,
) -> VF:
return terms.vf(t0, y0, args)


Ros3p.__init__.__doc__ = """**Arguments:** None"""
1 change: 1 addition & 0 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
diffrax.Kvaerno3(),
diffrax.Kvaerno4(),
diffrax.Kvaerno5(),
diffrax.Ros3p(),
)

all_split_solvers = (
Expand Down
7 changes: 7 additions & 0 deletions test/test_detest.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,12 @@ def _test(solver, problems, higher):
# size. (To avoid the adaptive step sizing sabotaging us.)
dt0 = 0.001
stepsize_controller = diffrax.ConstantStepSize()
elif type(solver) is diffrax.Ros3p and problem is _a1:
# Ros3p underestimates the error for _a1. This causes the step-size controller
# to take larger steps and results in an inaccurate solution.
dt0 = 0.0001
max_steps = 20_000_001
stepsize_controller = diffrax.ConstantStepSize()
else:
dt0 = None
if solver.order(term) < 4: # pyright: ignore
Expand All @@ -427,6 +433,7 @@ def _test(solver, problems, higher):
rtol = 1e-8
atol = 1e-8
stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol)

sol = diffrax.diffeqsolve(
term,
solver=solver,
Expand Down
4 changes: 4 additions & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def test_ode_order(solver, dtype):

A = jr.normal(akey, (10, 10), dtype=dtype) * 0.5

if isinstance(solver, diffrax.Ros3p) and dtype == jnp.complex128:
## complex support is not added to ros3p.
return

if (
solver.term_structure
== diffrax.MultiTerm[tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]
Expand Down
4 changes: 4 additions & 0 deletions test/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def test_derivative(dtype, getkey):
paths.append((local_linear_interp, "local linear", ys[0], ys[-1]))

for solver in all_ode_solvers:
if isinstance(solver, diffrax.Ros3p) and dtype == jnp.complex128:
# ros3p does not support complex type.
continue

solver = implicit_tol(solver)
y0 = jr.normal(getkey(), (3,), dtype=dtype)

Expand Down
39 changes: 36 additions & 3 deletions test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class _DoubleDopri5(diffrax.AbstractRungeKutta):
tableau: ClassVar[diffrax.MultiButcherTableau] = diffrax.MultiButcherTableau(
diffrax.Dopri5.tableau, diffrax.Dopri5.tableau
)
calculate_jacobian: ClassVar[diffrax.CalculateJacobian] = (
diffrax.CalculateJacobian.never
)
calculate_jacobian: ClassVar[
diffrax.CalculateJacobian
] = diffrax.CalculateJacobian.never

@staticmethod
def interpolation_cls(**kwargs):
Expand Down Expand Up @@ -415,6 +415,7 @@ def f2(t, y, args):
diffrax.KenCarp3(),
diffrax.KenCarp4(),
diffrax.KenCarp5(),
diffrax.Ros3p(),
),
)
def test_rober(solver):
Expand Down Expand Up @@ -479,6 +480,38 @@ def vector_field(t, y, args):
f(1.0)


def test_ros3p():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have a few tests that run pretty much every solver, it would be good to add ros3p to these as well.

term = diffrax.ODETerm(lambda t, y, args: -50.0 * y + jnp.sin(t))
solver = diffrax.Ros3p()
t0 = 0
t1 = 5
y0 = 0
ts = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64)
saveat = diffrax.SaveAt(ts=ts)

stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-12)
sol = diffrax.diffeqsolve(
term,
solver,
t0=t0,
t1=t1,
dt0=0.1,
y0=y0,
stepsize_controller=stepsize_controller,
max_steps=60000,
saveat=saveat,
)

def exact_sol(t):
return (
jnp.exp(-50.0 * t) * (y0 + 1 / 2501)
+ (50.0 * jnp.sin(t) - jnp.cos(t)) / 2501
)

ys_ref = jtu.tree_map(exact_sol, ts)
tree_allclose(ys_ref, sol.ys)


# Doesn't crash
def test_adaptive_dt0_semiimplicit_euler():
f = diffrax.ODETerm(lambda t, y, args: y)
Expand Down