Skip to content

Commit 88f0c14

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 1865057 + c6cc85c commit 88f0c14

28 files changed

+535
-359
lines changed

benchmarks/brownian_tree_times.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import cast, Optional, Union
1111
from typing_extensions import TypeAlias
1212

13+
import diffrax
1314
import equinox as eqx
1415
import equinox.internal as eqxi
1516
import jax
@@ -50,9 +51,9 @@ def __init__(
5051
tol: RealScalarLike,
5152
shape: tuple[int, ...],
5253
key: PRNGKeyArray,
53-
levy_area: str,
54+
levy_area: type[diffrax.AbstractBrownianIncrement] = diffrax.BrownianIncrement,
5455
):
55-
assert levy_area == ""
56+
assert levy_area == diffrax.BrownianIncrement
5657
self.t0 = t0
5758
self.t1 = t1
5859
self.tol = tol
@@ -187,13 +188,13 @@ def run(_ts):
187188
)
188189

189190

190-
for levy_area in ("", "space-time"):
191+
for levy_area in (diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea):
191192
print(f"- {levy_area=}")
192193
for tol in (2**-3, 2**-12):
193194
print(f"-- {tol=}")
194-
for num_ts in (1, 100):
195+
for num_ts in (1, 10000):
195196
print(f"--- {num_ts=}")
196-
if levy_area == "":
197+
if levy_area == diffrax.BrownianIncrement:
197198
print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}")
198199
print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}")
199200
print("")

diffrax/_adjoint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,14 @@ def _solve(inputs):
429429
)
430430

431431

432+
# Unwrap jaxtyping decorator during tests, so that these are global functions.
433+
# This is needed to ensure `optx.implicit_jvp` is happy.
434+
if _vf.__globals__["__name__"].startswith("jaxtyping"):
435+
_vf = _vf.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
436+
if _solve.__globals__["__name__"].startswith("jaxtyping"):
437+
_solve = _solve.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
438+
439+
432440
def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
433441
try:
434442
iter_x = iter(x) # pyright: ignore

diffrax/_brownian/path.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jax.tree_util as jtu
1010
import lineax.internal as lxi
1111
from jaxtyping import Array, PRNGKeyArray, PyTree
12+
from lineax.internal import complex_to_real_dtype
1213

1314
from .._custom_types import (
1415
AbstractBrownianIncrement,
@@ -44,7 +45,19 @@ class UnsafeBrownianPath(AbstractBrownianPath):
4445
motion. Hence the restrictions above. (They describe the general case for which the
4546
correlation structure isn't needed.)
4647
47-
Depending on the `levy_area` argument, this can also be used to generate Levy area.
48+
!!! info "Levy Area"
49+
50+
Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or
51+
`diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it
52+
also computes space-time Lévy area `H`. This is an additional source of
53+
randomness required for certain stochastic Runge--Kutta solvers; see
54+
[`diffrax.AbstractSRK`][] for more information.
55+
56+
An error will be thrown during tracing if Lévy area is required but is not
57+
available.
58+
59+
The choice here will impact the Brownian path, so even with the same key, the
60+
trajectory will be different depending on the value of `levy_area`.
4861
"""
4962

5063
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
@@ -130,7 +143,7 @@ def _evaluate_leaf(
130143
):
131144
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
132145
w = jr.normal(key, shape.shape, shape.dtype) * w_std
133-
dt = jnp.asarray(t1 - t0, dtype=shape.dtype)
146+
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))
134147

135148
if levy_area is SpaceTimeLevyArea:
136149
key, key_hh = jr.split(key, 2)

diffrax/_brownian/tree.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import jax.random as jr
1111
import jax.tree_util as jtu
1212
import lineax.internal as lxi
13-
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
13+
from jaxtyping import Array, Inexact, PRNGKeyArray, PyTree
14+
from lineax.internal import complex_to_real_dtype
1415

1516
from .._custom_types import (
1617
AbstractBrownianIncrement,
@@ -54,9 +55,9 @@
5455
# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
5556
# For the general interpolation rule for space-time Levy area see Theorem 6.1.4.
5657

57-
FloatDouble: TypeAlias = tuple[Float[Array, " *shape"], Float[Array, " *shape"]]
58+
FloatDouble: TypeAlias = tuple[Inexact[Array, " *shape"], Inexact[Array, " *shape"]]
5859
FloatTriple: TypeAlias = tuple[
59-
Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"]
60+
Inexact[Array, " *shape"], Inexact[Array, " *shape"], Inexact[Array, " *shape"]
6061
]
6162
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]
6263
_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement)
@@ -90,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
9091
assert len(x1) == 2
9192
dt0, w0 = x0
9293
dt1, w1 = x1
93-
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
94+
su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
9495
return BrownianIncrement(dt=su, W=w1 - w0)
9596

9697
elif len(x0) == 4: # space-time levy area case
@@ -99,12 +100,13 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
99100
dt1, w1, hh1, bhh1 = x1
100101

101102
w_su = w1 - w0
102-
su = jnp.asarray(dt1 - dt0, dtype=w0.dtype)
103+
su = jnp.asarray(dt1 - dt0, dtype=complex_to_real_dtype(w0.dtype))
103104
_su = jnp.where(jnp.abs(su) < jnp.finfo(su).eps, jnp.inf, su)
104105
inverse_su = 1 / _su
105-
u_bb_s = dt1 * w0 - dt0 * w1
106-
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
107-
hh_su = inverse_su * bhh_su
106+
with jax.numpy_dtype_promotion("standard"):
107+
u_bb_s = dt1 * w0 - dt0 * w1
108+
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
109+
hh_su = inverse_su * bhh_su
108110
return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su)
109111
else:
110112
assert False
@@ -135,10 +137,19 @@ def _split_interval(
135137
class VirtualBrownianTree(AbstractBrownianPath):
136138
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
137139
138-
Can be initialised with `levy_area` set to `""`, or `"space-time"`.
139-
If `levy_area="space_time"`, then it also computes space-time Lévy area `H`.
140-
This will impact the Brownian path, so even with the same key, the trajectory will
141-
be different depending on the value of `levy_area`.
140+
!!! info "Levy Area"
141+
142+
Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or
143+
`diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it
144+
also computes space-time Lévy area `H`. This is an additional source of
145+
randomness required for certain stochastic Runge--Kutta solvers; see
146+
[`diffrax.AbstractSRK`][] for more information.
147+
148+
An error will be thrown during tracing if Lévy area is required but is not
149+
available.
150+
151+
The choice here will impact the Brownian path, so even with the same key, the
152+
trajectory will be different depending on the value of `levy_area`.
142153
143154
??? cite "Reference"
144155
@@ -283,9 +294,10 @@ def _evaluate_leaf(
283294
tuple[RealScalarLike, Array], tuple[RealScalarLike, Array, Array, Array]
284295
]:
285296
shape, dtype = struct.shape, struct.dtype
297+
tdtype = complex_to_real_dtype(dtype)
286298

287-
t0 = jnp.zeros((), dtype)
288-
r = jnp.asarray(r, dtype)
299+
t0 = jnp.zeros((), tdtype)
300+
r = jnp.asarray(r, tdtype)
289301

290302
if self.levy_area is SpaceTimeLevyArea:
291303
state_key, init_key_w, init_key_la = jr.split(key, 3)
@@ -394,14 +406,33 @@ def _body_fun(_state: _State):
394406
a = d_prime * sr3 * sr_ru_half
395407
b = d_prime * ru3 * sr_ru_half
396408

397-
w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
398-
w_r = w_s + w_sr
399-
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
400-
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
401-
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
409+
with jax.numpy_dtype_promotion("standard"):
410+
w_sr = (
411+
sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
412+
)
413+
w_r = w_s + w_sr
414+
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
415+
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
416+
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
402417

403-
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
404-
hh_r = inverse_r * bhh_r
418+
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
419+
hh_r = inverse_r * bhh_r
420+
421+
elif self.levy_area is BrownianIncrement:
422+
with jax.numpy_dtype_promotion("standard"):
423+
w_mean = w_s + sr / su * w_su
424+
if self._spline == "sqrt":
425+
z = jr.normal(final_state.key, shape, dtype)
426+
bb = jnp.sqrt(sr * ru / su) * z
427+
elif self._spline == "quad":
428+
z = jr.normal(final_state.key, shape, dtype)
429+
bb = (sr * ru / su) * z
430+
elif self._spline == "zero":
431+
bb = jnp.zeros(shape, dtype)
432+
else:
433+
assert False
434+
w_r = w_mean + bb
435+
return r, w_r
405436

406437
elif self.levy_area is BrownianIncrement:
407438
w_mean = w_s + sr / su * w_su
@@ -497,8 +528,8 @@ def _brownian_arch(
497528

498529
w_t = w_s + w_st
499530
w_stu = (w_s, w_t, w_u)
500-
501-
bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
531+
with jax.numpy_dtype_promotion("standard"):
532+
bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t)
502533
bhh_stu = (bhh_s, bhh_t, bhh_u)
503534
bkk_stu = None
504535
bkk_st_tu = None

diffrax/_global_interpolation.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ def _index(_ys):
143143
next_t = self.ts[index + 1]
144144
diff_t = next_t - prev_t
145145

146-
return (prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)).ω
146+
with jax.numpy_dtype_promotion("standard"):
147+
return (
148+
prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
149+
).ω
147150

148151
@eqx.filter_jit
149152
def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
@@ -165,10 +168,11 @@ def derivative(self, t: RealScalarLike, left: bool = True) -> PyTree[Array]:
165168

166169
index, _ = self._interpret_t(t, left)
167170

168-
return (
169-
(ω(self.ys)[index + 1] - ω(self.ys)[index])
170-
/ (self.ts[index + 1] - self.ts[index])
171-
).ω
171+
with jax.numpy_dtype_promotion("standard"):
172+
return (
173+
(ω(self.ys)[index + 1] - ω(self.ys)[index])
174+
/ (self.ts[index + 1] - self.ts[index])
175+
).ω
172176

173177

174178
LinearInterpolation.__init__.__doc__ = """**Arguments:**
@@ -254,10 +258,11 @@ def evaluate(
254258

255259
d, c, b, a = self.coeffs
256260

257-
return (
258-
ω(a)[index]
259-
+ frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index]))
260-
).ω
261+
with jax.numpy_dtype_promotion("standard"):
262+
return (
263+
ω(a)[index]
264+
+ frac * (ω(b)[index] + frac * (ω(c)[index] + frac * ω(d)[index]))
265+
).ω
261266

262267
@eqx.filter_jit
263268
def derivative(
@@ -283,7 +288,8 @@ def derivative(
283288

284289
d, c, b, _ = self.coeffs
285290

286-
return (ω(b)[index] + frac * (2 * ω(c)[index] + frac * 3 * ω(d)[index])).ω
291+
with jax.numpy_dtype_promotion("standard"):
292+
return (ω(b)[index] + frac * (2 * ω(c)[index] + frac * 3 * ω(d)[index])).ω
287293

288294

289295
CubicInterpolation.__init__.__doc__ = """**Arguments:**
@@ -622,8 +628,9 @@ def _hermite_forward(
622628
]:
623629
prev_ti, prev_yi, prev_deriv_i = carry
624630
ti, yi, next_ti, next_yi = value
625-
first_deriv_i = (next_yi - yi) / (next_ti - ti)
626-
later_deriv_i = (yi - prev_yi) / (ti - prev_ti)
631+
with jax.numpy_dtype_promotion("standard"):
632+
first_deriv_i = (next_yi - yi) / (next_ti - ti)
633+
later_deriv_i = (yi - prev_yi) / (ti - prev_ti)
627634
deriv_i = jnp.where(jnp.isnan(prev_yi), first_deriv_i, later_deriv_i)
628635
cond = jnp.isnan(yi)
629636
carry_ti = jnp.where(cond, prev_ti, ti)
@@ -635,13 +642,15 @@ def _hermite_forward(
635642

636643
def _hermite_coeffs(t0, y0, deriv0, t1, y1):
637644
t_diff = t1 - t0
638-
deriv1 = (y1 - y0) / t_diff
639-
d_deriv = deriv1 - deriv0
640645

641-
a = y0
642-
b = deriv0
643-
c = 2 * d_deriv / t_diff
644-
d = -d_deriv / t_diff**2
646+
with jax.numpy_dtype_promotion("standard"):
647+
deriv1 = (y1 - y0) / t_diff
648+
d_deriv = deriv1 - deriv0
649+
650+
a = y0
651+
b = deriv0
652+
c = 2 * d_deriv / t_diff
653+
d = -d_deriv / (t_diff**2)
645654

646655
return d, c, b, a
647656

@@ -684,7 +693,8 @@ def _backward_hermite_coefficients(
684693
else:
685694
y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape)
686695
if deriv0 is None:
687-
deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
696+
with jax.numpy_dtype_promotion("standard"):
697+
deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0)
688698
else:
689699
deriv0 = jnp.broadcast_to(deriv0, ys[0].shape)
690700
ts = ts[:-1]

diffrax/_integrate.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ def _check(term_cls, term, term_contr_kwargs, yi):
156156
# If we've got to this point then the term is compatible
157157

158158
try:
159-
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
159+
with jax.numpy_dtype_promotion("standard"):
160+
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
160161
except ValueError:
161162
# ValueError may also arise from mismatched tree structures
162163
return False
@@ -709,11 +710,6 @@ def diffeqsolve(
709710
eqx.is_array(xi) and jnp.iscomplexobj(xi)
710711
for xi in jtu.tree_leaves((terms, y0, args))
711712
):
712-
if isinstance(solver, AbstractImplicitSolver):
713-
raise ValueError(
714-
"Implicit solvers in conjunction with complex dtypes is currently not "
715-
"supported."
716-
)
717713
warnings.warn(
718714
"Complex dtype support is work in progress, please read "
719715
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.",
@@ -808,9 +804,6 @@ def _promote(yi):
808804
"`UnsafeBrownianPath` cannot be used with adaptive step sizes."
809805
)
810806

811-
if isinstance(solver, KLSolver):
812-
y0 = (y0, 0.0)
813-
y0 = jtu.tree_map(_promote, y0)
814807
# Normalises time: if t0 > t1 then flip things around.
815808
direction = jnp.where(t0 < t1, 1, -1)
816809
t0 = t0 * direction

0 commit comments

Comments
 (0)