Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
98c6922
kl
lockwo Apr 17, 2024
3b4e07c
add more training lengths
lockwo Apr 18, 2024
7d62553
Add progress bar
abocquet Jan 18, 2024
83eb8bf
Progress bar tweaks.
patrick-kidger Feb 1, 2024
6c93faa
got Rid of internally using LevyVal in VBT
andyElking Feb 16, 2024
34cbe5c
Parametric control types (#364)
tttc3 Feb 20, 2024
09cadab
Fixup test_term static type checking
tttc3 Apr 21, 2024
322852a
Create py.typed
lockwo Apr 21, 2024
41c58a6
PID for complex dtype fixes (#391)
Randl Apr 22, 2024
6200412
Merge branch 'patrick-kidger:main' into main
lockwo Apr 23, 2024
085b68d
KL solver wrapper draft
lockwo Apr 24, 2024
f603f38
add saves
lockwo Apr 24, 2024
97c1fc4
minor fixes, more to come
lockwo Apr 25, 2024
c9e1573
intermediate work
lockwo Apr 26, 2024
f21f26d
finalization for review
lockwo Apr 27, 2024
2b56424
forgot saveat
lockwo Apr 27, 2024
633afbd
_control term isn't recognized for some reason
lockwo Apr 27, 2024
55d3c0f
Everything SRK related squashed on top of diffrax/dev (#344)
andyElking Apr 27, 2024
6e34acd
Merge branch 'dev'
lockwo Apr 27, 2024
7ea3127
fix test
lockwo Apr 27, 2024
97ce9b0
Doc tweaks for SRK. (#409)
patrick-kidger Apr 27, 2024
35af3e9
3.9 fix
lockwo Apr 27, 2024
f9e40e1
3.9 fix2
lockwo May 3, 2024
ac28ce4
fixed brownian_tree_times.py (#410)
andyElking May 4, 2024
a998093
Fixed progress meters with `jax.grad`
patrick-kidger Apr 7, 2024
c4deca4
Enable more complex tests, fix related errors (#392)
Randl May 4, 2024
1865057
a
lockwo May 8, 2024
4cb2475
Enable implicit solvers for complex inputs (#411)
Randl May 13, 2024
daed558
Bump version
patrick-kidger May 18, 2024
6abbd3b
Improved Levy area documentation
patrick-kidger May 19, 2024
dd633e4
Merge remote-tracking branch 'upstream/dev'
lockwo Jun 3, 2024
7e34706
adjoint implicit fix
lockwo Jun 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@
ODETerm as ODETerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
)
from ._kl_term import (
sde_kl_divergence
)


__version__ = importlib.metadata.version("diffrax")
175 changes: 175 additions & 0 deletions diffrax/_kl_term.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import operator
from typing import Tuple, Callable

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
from jaxtyping import Array, PyTree
from ._custom_types import RealScalarLike, Control
from ._term import (
AbstractTerm,
ControlTerm,
MultiTerm,
WeaklyDiagonalControlTerm,
)


def _kl_diagonal(drift: Array, diffusion: Array):
"""This is the case where diffusion matrix is
a diagonal matrix
"""
diffusion = jnp.where(
jax.lax.stop_gradient(diffusion) > 1e-7,
diffusion,
jnp.full_like(diffusion, fill_value=1e-7) * jnp.sign(diffusion),
)
scale = drift / diffusion
return 0.5 * jnp.sum(scale**2)


def _kl_full_matrix(drift: Array, diffusion: Array):
"""General case"""
scale = jnp.linalg.pinv(diffusion) @ drift
return 0.5 * jnp.sum(scale**2)

def _handle(drift: Array, diffusion: Array):
"""According to the shape of drift and diffusion,
select the right way to compute KL divergence
"""
eqx.error_if(drift, not eqx.is_array(drift), "Only array drifts are supported")
eqx.error_if(diffusion, not eqx.is_array(diffusion), "Only array diffusion are supported")
if drift.shape == diffusion.shape:
return _kl_diagonal(drift, diffusion)
else:
return _kl_full_matrix(drift, diffusion)


def _kl_block_diffusion(drift: PyTree, diffusion: PyTree):
"""The case where diffusion matrix is a block diagonal matrix"""
kl = jtu.tree_map(
_handle,
drift,
diffusion,
)

kl = jtu.tree_reduce(
operator.add,
kl,
)
return kl


class _AugDrift(AbstractTerm):

drift1: Callable
drift2: Callable
diffusion: AbstractTerm

def vf(self, t: RealScalarLike, y: PyTree, args) -> PyTree:
# In this implementation, we may restricte our case where the
# diffusion can be a block matrix. Each block can follow
# different `vf_prod`
# - PyTree of drift: (*, *, ..., *) :
# - PyTree of diffusion: (*, *, ..., *)
# For example,
# - output of drift can be
# drift = {"block1": jnp.zeros((2,)),
# "block2": jnp.zeros((2,)),
# "block3": jnp.zeros((3,))}
# - output of diffusion (which mixes between the two types)
# diffusion = {"block1": jnp.ones((2,)), #-> WeaklyDiagonal
# "block2": jnp.ones((2, 3)), #-> General case
# "block3": jnp.ones((3, 4))} #-> General case
#
# NOTE: `args` will take `context` as a function (normally, `args`
# is PyTree)

y, _ = y

# check if there is context
context = args
aug_y = y if context is None else jnp.concatenate([y, context(t)], axis=-1)

drift1 = self.drift1(t, aug_y, args) # we can't make these .vf becuase _broadcast_and_upcast
# requires that aug_y and drift(aug_y) are the same shape, but they aren't
drift2 = self.drift2(t, y, args)

drift = jtu.tree_map(operator.sub, drift1, drift2)
diffusion = self.diffusion.vf(t, y, args)

# get tree structure of drift and diffusion
drift_tree_structure = jtu.tree_structure(drift)
diffusion_tree_structure = jtu.tree_structure(diffusion)

if drift_tree_structure == diffusion_tree_structure:
# drift and diffusion has the same tree structure
# check the shape to determine how to compute KL
# however, it does not check the abstract yet

if isinstance(drift, jnp.ndarray):
# this case PyTree is (*)

# here we check the abstract level of ControlTerm
if isinstance(self.diffusion, WeaklyDiagonalControlTerm):
# diffusion must be jnp.ndarrary as well because
# diffusion and drift has the same structure
# therefore we don't need to check type of diffusion here
kl_divergence = _kl_diagonal(drift, diffusion)
elif isinstance(self.diffusion, ControlTerm):
kl_divergence = _kl_full_matrix(drift, diffusion)
else:
# a more general case, we assume that on each leave,
# if drift and diffusion have the same shape
# -> WeaklyDiagonalControlTerm
# else
# -> ControlTerm
kl_divergence = _kl_block_diffusion(drift, diffusion)
else:
raise ValueError(
"drift and diffusion should have the same PyTree structure"
+ f" \n {drift_tree_structure} != {diffusion_tree_structure}"
)
return drift1, kl_divergence

def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control:
return t1 - t0

def prod(self, vf: PyTree, control: RealScalarLike) -> PyTree:
return jtu.tree_map(lambda v: control * v, vf)


class _AugControlTerm(AbstractTerm):

control_term: AbstractTerm

def __init__(self, term: AbstractTerm) -> None:
self.control_term = term

def vf(self, t: RealScalarLike, y: PyTree, args: PyTree) -> PyTree:
y, _ = y
vf = self.control_term.vf(t, y, args)
return vf, 0.0

def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> PyTree:
return self.control_term.contr(t0, t1), 0.0

def vf_prod(self, t: RealScalarLike, y: PyTree, args: PyTree, control: PyTree) -> PyTree:
y, _ = y
control, _ = control
return self.control_term.vf_prod(t, y, args, control), 0.0

def prod(self, vf: PyTree, control: PyTree) -> PyTree:
vf, _ = vf
control, _ = control
return self.control_term.prod(vf, control), 0.0


def sde_kl_divergence(
drift1: Callable, drift2: Callable, diffusion: AbstractTerm, y0: PyTree
) -> Tuple[MultiTerm, PyTree]:
aug_y0 = (y0, 0.0)
aug_drift = _AugDrift(drift1, drift2, diffusion)
aug_control = _AugControlTerm(diffusion)
aug_sde = MultiTerm(aug_drift, aug_control)
return aug_sde, aug_y0
Loading