Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
98c6922
kl
lockwo Apr 17, 2024
3b4e07c
add more training lengths
lockwo Apr 18, 2024
7d62553
Add progress bar
abocquet Jan 18, 2024
83eb8bf
Progress bar tweaks.
patrick-kidger Feb 1, 2024
6c93faa
got Rid of internally using LevyVal in VBT
andyElking Feb 16, 2024
34cbe5c
Parametric control types (#364)
tttc3 Feb 20, 2024
09cadab
Fixup test_term static type checking
tttc3 Apr 21, 2024
322852a
Create py.typed
lockwo Apr 21, 2024
41c58a6
PID for complex dtype fixes (#391)
Randl Apr 22, 2024
6200412
Merge branch 'patrick-kidger:main' into main
lockwo Apr 23, 2024
085b68d
KL solver wrapper draft
lockwo Apr 24, 2024
f603f38
add saves
lockwo Apr 24, 2024
97c1fc4
minor fixes, more to come
lockwo Apr 25, 2024
c9e1573
intermediate work
lockwo Apr 26, 2024
f21f26d
finalization for review
lockwo Apr 27, 2024
2b56424
forgot saveat
lockwo Apr 27, 2024
633afbd
_control term isn't recognized for some reason
lockwo Apr 27, 2024
55d3c0f
Everything SRK related squashed on top of diffrax/dev (#344)
andyElking Apr 27, 2024
6e34acd
Merge branch 'dev'
lockwo Apr 27, 2024
7ea3127
fix test
lockwo Apr 27, 2024
97ce9b0
Doc tweaks for SRK. (#409)
patrick-kidger Apr 27, 2024
35af3e9
3.9 fix
lockwo Apr 27, 2024
f9e40e1
3.9 fix2
lockwo May 3, 2024
ac28ce4
fixed brownian_tree_times.py (#410)
andyElking May 4, 2024
a998093
Fixed progress meters with `jax.grad`
patrick-kidger Apr 7, 2024
c4deca4
Enable more complex tests, fix related errors (#392)
Randl May 4, 2024
1865057
a
lockwo May 8, 2024
4cb2475
Enable implicit solvers for complex inputs (#411)
Randl May 13, 2024
daed558
Bump version
patrick-kidger May 18, 2024
6abbd3b
Improved Levy area documentation
patrick-kidger May 19, 2024
dd633e4
Merge remote-tracking branch 'upstream/dev'
lockwo Jun 3, 2024
7e34706
adjoint implicit fix
lockwo Jun 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.2.2
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.316
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]
11 changes: 6 additions & 5 deletions benchmarks/brownian_tree_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import cast, Optional, Union
from typing_extensions import TypeAlias

import diffrax
import equinox as eqx
import equinox.internal as eqxi
import jax
Expand Down Expand Up @@ -50,9 +51,9 @@ def __init__(
tol: RealScalarLike,
shape: tuple[int, ...],
key: PRNGKeyArray,
levy_area: str,
levy_area: type[diffrax.AbstractBrownianIncrement] = diffrax.BrownianIncrement,
):
assert levy_area == ""
assert levy_area == diffrax.BrownianIncrement
self.t0 = t0
self.t1 = t1
self.tol = tol
Expand Down Expand Up @@ -187,13 +188,13 @@ def run(_ts):
)


for levy_area in ("", "space-time"):
for levy_area in (diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea):
print(f"- {levy_area=}")
for tol in (2**-3, 2**-12):
print(f"-- {tol=}")
for num_ts in (1, 100):
for num_ts in (1, 10000):
print(f"--- {num_ts=}")
if levy_area == "":
if levy_area == diffrax.BrownianIncrement:
print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}")
print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}")
print("")
12 changes: 6 additions & 6 deletions benchmarks/small_neural_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class FuncTorch(torch.nn.Module):
def __init__(self):
super().__init__()
self.func = torch.jit.script( # pyright: ignore
self.func = torch.jit.script(
torch.nn.Sequential(
torch.nn.Linear(4, 32),
torch.nn.Softplus(),
Expand All @@ -30,7 +30,7 @@ def __init__(self):
)

def forward(self, t, y):
return self.func(y) # pyright: ignore
return self.func(y)


class FuncJax(eqx.Module):
Expand Down Expand Up @@ -177,10 +177,10 @@ def run(multiple, grad, batch_size=64, t1=100):
with torch.no_grad():
func_jax = neural_ode_diffrax.func.func
func_torch = neural_ode_torch.func.func
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight))) # pyright: ignore
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias))) # pyright: ignore
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight))) # pyright: ignore
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) # pyright: ignore
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

y0_jax = jr.normal(jr.PRNGKey(1), (batch_size, 4))
y0_torch = torch.tensor(np.asarray(y0_jax))
Expand Down
24 changes: 23 additions & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import LevyVal as LevyVal
from ._custom_types import (
AbstractBrownianIncrement as AbstractBrownianIncrement,
AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
BrownianIncrement as BrownianIncrement,
SpaceTimeLevyArea as SpaceTimeLevyArea,
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
Expand All @@ -37,6 +44,12 @@
)
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
from ._path import AbstractPath as AbstractPath
from ._progress_meter import (
AbstractProgressMeter as AbstractProgressMeter,
NoProgressMeter as NoProgressMeter,
TextProgressMeter as TextProgressMeter,
TqdmProgressMeter as TqdmProgressMeter,
)
from ._root_finder import (
VeryChord as VeryChord,
with_stepsize_controller_tols as with_stepsize_controller_tols,
Expand All @@ -59,6 +72,7 @@
AbstractRungeKutta as AbstractRungeKutta,
AbstractSDIRK as AbstractSDIRK,
AbstractSolver as AbstractSolver,
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
Bosh3 as Bosh3,
Expand All @@ -68,13 +82,15 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
GeneralShARK as GeneralShARK,
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
ItoMilstein as ItoMilstein,
KenCarp3 as KenCarp3,
KenCarp4 as KenCarp4,
KenCarp5 as KenCarp5,
KLSolver as KLSolver,
Kvaerno3 as Kvaerno3,
Kvaerno4 as Kvaerno4,
Kvaerno5 as Kvaerno5,
Expand All @@ -83,8 +99,14 @@
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
Sil3 as Sil3,
SlowRK as SlowRK,
SPaRK as SPaRK,
SRA1 as SRA1,
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
)
Expand Down
19 changes: 16 additions & 3 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
import functools as ft
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -20,6 +20,9 @@
from ._term import AbstractTerm, AdjointTerm


ω = cast(Callable, ω)


def _is_none(x):
return x is None

Expand Down Expand Up @@ -128,6 +131,7 @@ def loop(
init_state,
passed_solver_state,
passed_controller_state,
progress_meter,
) -> Any:
"""Runs the main solve loop. Subclasses can override this to provide custom
backpropagation behaviour; see for example the implementation of
Expand Down Expand Up @@ -425,6 +429,14 @@ def _solve(inputs):
)


# Unwrap jaxtyping decorator during tests, so that these are global functions.
# This is needed to ensure `optx.implicit_jvp` is happy.
if _vf.__globals__["__name__"].startswith("jaxtyping"):
_vf = _vf.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
if _solve.__globals__["__name__"].startswith("jaxtyping"):
_solve = _solve.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
try:
iter_x = iter(x) # pyright: ignore
Expand Down Expand Up @@ -559,6 +571,7 @@ def _loop_backsolve_bwd(
max_steps,
throw,
init_state,
progress_meter,
):
assert discrete_terminating_event is None

Expand All @@ -567,7 +580,7 @@ def _loop_backsolve_bwd(
# using them later.
#

del perturbed, init_state, t1
del perturbed, init_state, t1, progress_meter
ts, ys = residuals
del residuals
grad_final_state, _ = grad_final_state__aux_stats
Expand Down
25 changes: 24 additions & 1 deletion diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@
from ._saveat import SubSaveAt
from ._solver import (
AbstractImplicitSolver,
AbstractItoSolver,
AbstractSRK,
AbstractStratonovichSolver,
Dopri5,
Dopri8,
GeneralShARK,
Kvaerno3,
Kvaerno4,
Kvaerno5,
LeapfrogMidpoint,
ReversibleHeun,
SEA,
SemiImplicitEuler,
ShARK,
SlowRK,
SPaRK,
SRA1,
Tsit5,
)
from ._step_size_controller import PIDController
Expand Down Expand Up @@ -374,7 +383,15 @@ def _backsolve_rms_norm(adjoint):

@citation_rules.append
def _explicit_solver(solver, terms=None):
if not isinstance(solver, AbstractImplicitSolver) and not is_sde(terms):
if not isinstance(
solver,
(
AbstractImplicitSolver,
AbstractSRK,
AbstractItoSolver,
AbstractStratonovichSolver,
),
) and not is_sde(terms):
return r"""
% You are using an explicit solver, and may wish to cite the standard textbook:
@book{hairer2008solving-i,
Expand Down Expand Up @@ -467,6 +484,12 @@ def _solvers(solver, saveat=None):
Kvaerno5,
ReversibleHeun,
LeapfrogMidpoint,
ShARK,
SRA1,
SlowRK,
GeneralShARK,
SPaRK,
SEA,
):
return (
r"""
Expand Down
18 changes: 13 additions & 5 deletions diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
import abc
from typing import Optional, Union
from typing import Optional, TypeVar, Union

from equinox.internal import AbstractVar
from jaxtyping import Array, PyTree

from .._custom_types import LevyArea, LevyVal, RealScalarLike
from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
RealScalarLike,
SpaceTimeLevyArea,
)
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianIncrement])


class AbstractBrownianPath(AbstractPath[_Control]):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea]]]

@abc.abstractmethod
def evaluate(
Expand All @@ -20,7 +28,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
) -> _Control:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.

Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
Expand Down
Loading