Add KL divergence terms for Latent SDEs #402
Closed
32 commits
98c6922 kl (lockwo)
3b4e07c add more training lengths (lockwo)
7d62553 Add progress bar (abocquet)
83eb8bf Progress bar tweaks. (patrick-kidger)
6c93faa got Rid of internally using LevyVal in VBT (andyElking)
34cbe5c Parametric control types (#364) (tttc3)
09cadab Fixup test_term static type checking (tttc3)
322852a Create py.typed (lockwo)
41c58a6 PID for complex dtype fixes (#391) (Randl)
6200412 Merge branch 'patrick-kidger:main' into main (lockwo)
085b68d KL solver wrapper draft (lockwo)
f603f38 add saves (lockwo)
97c1fc4 minor fixes, more to come (lockwo)
c9e1573 intermediate work (lockwo)
f21f26d finalization for review (lockwo)
2b56424 forgot saveat (lockwo)
633afbd _control term isn't recognized for some reason (lockwo)
55d3c0f Everything SRK related squashed on top of diffrax/dev (#344) (andyElking)
6e34acd Merge branch 'dev' (lockwo)
7ea3127 fix test (lockwo)
97ce9b0 Doc tweaks for SRK. (#409) (patrick-kidger)
35af3e9 3.9 fix (lockwo)
f9e40e1 3.9 fix2 (lockwo)
ac28ce4 fixed brownian_tree_times.py (#410) (andyElking)
a998093 Fixed progress meters with `jax.grad` (patrick-kidger)
c4deca4 Enable more complex tests, fix related errors (#392) (Randl)
1865057 a (lockwo)
4cb2475 Enable implicit solvers for complex inputs (#411) (Randl)
daed558 Bump version (patrick-kidger)
6abbd3b Improved Levy area documentation (patrick-kidger)
dd633e4 Merge remote-tracking branch 'upstream/dev' (lockwo)
7e34706 adjoint implicit fix (lockwo)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| import operator | ||
| from typing import Tuple, Callable | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import jax.tree_util as jtu | ||
| import equinox as eqx | ||
| from jaxtyping import Array, PyTree | ||
| from ._custom_types import RealScalarLike, Control | ||
| from ._term import ( | ||
| AbstractTerm, | ||
| ControlTerm, | ||
| MultiTerm, | ||
| WeaklyDiagonalControlTerm, | ||
| ) | ||
|
|
||
|
|
||
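| def _kl_diagonal(drift: Array, diffusion: Array): | ||
| """KL integrand for the case where the diffusion matrix is diagonal.""" | ||
| # Clamp entries of small magnitude to +/-1e-7, preserving their sign, to | ||
| # avoid dividing by (near-)zero. Exact zeros are treated as positive. | ||
| sign = jnp.where(diffusion < 0, -1.0, 1.0) | ||
| diffusion = jnp.where( | ||
| jnp.abs(jax.lax.stop_gradient(diffusion)) > 1e-7, | ||
| diffusion, | ||
| jnp.full_like(diffusion, fill_value=1e-7) * sign, | ||
| ) | ||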
| scale = drift / diffusion | ||
| return 0.5 * jnp.sum(scale**2) | ||
|
|
||
|
|
||
| def _kl_full_matrix(drift: Array, diffusion: Array): | ||
| """General case""" | ||
| scale = jnp.linalg.pinv(diffusion) @ drift | ||
| return 0.5 * jnp.sum(scale**2) | ||
|
|
||
| def _handle(drift: Array, diffusion: Array): | ||
| """Select how to compute the KL divergence from the shapes of drift and | ||
| diffusion: equal shapes are treated as diagonal, anything else as a full matrix. | ||
| """ | ||
| # Static checks, performed at trace time. | ||
| if not eqx.is_array(drift): | ||
| raise ValueError("Only array drifts are supported") | ||
| if not eqx.is_array(diffusion): | ||
| raise ValueError("Only array diffusions are supported") | ||
| if drift.shape == diffusion.shape: | ||
| return _kl_diagonal(drift, diffusion) | ||
| else: | ||
| return _kl_full_matrix(drift, diffusion) | ||
|
|
||
|
|
||
| def _kl_block_diffusion(drift: PyTree, diffusion: PyTree): | ||
| """The case where diffusion matrix is a block diagonal matrix""" | ||
| kl = jtu.tree_map( | ||
| _handle, | ||
| drift, | ||
| diffusion, | ||
| ) | ||
|
|
||
| kl = jtu.tree_reduce( | ||
| operator.add, | ||
| kl, | ||
| ) | ||
| return kl | ||
|
|
||
|
|
||
| class _AugDrift(AbstractTerm): | ||
|
|
||
| drift1: Callable | ||
| drift2: Callable | ||
| diffusion: AbstractTerm | ||
|
|
||
| def vf(self, t: RealScalarLike, y: PyTree, args) -> PyTree: | ||
| # In this implementation, we restrict ourselves to the case where the | ||
| # diffusion can be a block matrix. Each block can follow a | ||
| # different `vf_prod`. | ||
| # - PyTree of drift: (*, *, ..., *) : | ||
| # - PyTree of diffusion: (*, *, ..., *) | ||
| # For example, | ||
| # - output of drift can be | ||
| # drift = {"block1": jnp.zeros((2,)), | ||
| # "block2": jnp.zeros((2,)), | ||
| # "block3": jnp.zeros((3,))} | ||
| # - output of diffusion (which mixes between the two types) | ||
| # diffusion = {"block1": jnp.ones((2,)), #-> WeaklyDiagonal | ||
| # "block2": jnp.ones((2, 3)), #-> General case | ||
| # "block3": jnp.ones((3, 4))} #-> General case | ||
| # | ||
| # NOTE: here `args` is expected to be either `None` or a `context` | ||
| # function of `t` (normally, `args` is an arbitrary PyTree) | ||
|
|
||
| y, _ = y | ||
|
|
||
| # check if there is context | ||
| context = args | ||
| aug_y = y if context is None else jnp.concatenate([y, context(t)], axis=-1) | ||
|
|
||
| # We can't make these `.vf` calls because `_broadcast_and_upcast` requires | ||
| # that aug_y and drift(aug_y) have the same shape, but they don't. | ||
| drift1 = self.drift1(t, aug_y, args) | ||
| drift2 = self.drift2(t, y, args) | ||
|
|
||
| drift = jtu.tree_map(operator.sub, drift1, drift2) | ||
| diffusion = self.diffusion.vf(t, y, args) | ||
|
|
||
| # get tree structure of drift and diffusion | ||
| drift_tree_structure = jtu.tree_structure(drift) | ||
| diffusion_tree_structure = jtu.tree_structure(diffusion) | ||
|
|
||
| if drift_tree_structure == diffusion_tree_structure: | ||
| # drift and diffusion have the same tree structure, so | ||
| # check the shapes to determine how to compute the KL; | ||
| # note that this does not yet check the abstract term type | ||
|
|
||
| if isinstance(drift, jnp.ndarray): | ||
| # this case PyTree is (*) | ||
|
|
||
| # here we check the abstract level of ControlTerm | ||
| if isinstance(self.diffusion, WeaklyDiagonalControlTerm): | ||
| # diffusion must be a jnp.ndarray as well because | ||
| # diffusion and drift have the same structure, | ||
| # so we don't need to check the type of diffusion here | ||
| kl_divergence = _kl_diagonal(drift, diffusion) | ||
| elif isinstance(self.diffusion, ControlTerm): | ||
| kl_divergence = _kl_full_matrix(drift, diffusion) | ||
| else: | ||
| # a more general case: we assume that on each leaf, | ||
| # if drift and diffusion have the same shape | ||
| # -> WeaklyDiagonalControlTerm | ||
| # else | ||
| # -> ControlTerm | ||
| kl_divergence = _kl_block_diffusion(drift, diffusion) | ||
| else: | ||
| raise ValueError( | ||
| "drift and diffusion should have the same PyTree structure" | ||
| + f" \n {drift_tree_structure} != {diffusion_tree_structure}" | ||
| ) | ||
| return drift1, kl_divergence | ||
|
|
||
| def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> Control: | ||
| return t1 - t0 | ||
|
|
||
| def prod(self, vf: PyTree, control: RealScalarLike) -> PyTree: | ||
| return jtu.tree_map(lambda v: control * v, vf) | ||
|
|
||
|
|
||
| class _AugControlTerm(AbstractTerm): | ||
|
|
||
| control_term: AbstractTerm | ||
|
|
||
| def __init__(self, term: AbstractTerm) -> None: | ||
| self.control_term = term | ||
|
|
||
| def vf(self, t: RealScalarLike, y: PyTree, args: PyTree) -> PyTree: | ||
| y, _ = y | ||
| vf = self.control_term.vf(t, y, args) | ||
| return vf, 0.0 | ||
|
|
||
| def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> PyTree: | ||
| return self.control_term.contr(t0, t1), 0.0 | ||
|
|
||
| def vf_prod(self, t: RealScalarLike, y: PyTree, args: PyTree, control: PyTree) -> PyTree: | ||
| y, _ = y | ||
| control, _ = control | ||
| return self.control_term.vf_prod(t, y, args, control), 0.0 | ||
|
|
||
| def prod(self, vf: PyTree, control: PyTree) -> PyTree: | ||
| vf, _ = vf | ||
| control, _ = control | ||
| return self.control_term.prod(vf, control), 0.0 | ||
|
|
||
|
|
||
| def sde_kl_divergence( | ||
| drift1: Callable, drift2: Callable, diffusion: AbstractTerm, y0: PyTree | ||
| ) -> Tuple[MultiTerm, PyTree]: | ||
| aug_y0 = (y0, 0.0) | ||
| aug_drift = _AugDrift(drift1, drift2, diffusion) | ||
| aug_control = _AugControlTerm(diffusion) | ||
| aug_sde = MultiTerm(aug_drift, aug_control) | ||
| return aug_sde, aug_y0 | ||
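
For reference, the quantity that `_kl_diagonal` and `_kl_full_matrix` compute in the diff above is half the squared norm of the drift difference rescaled by the (pseudo-)inverse of the diffusion. The following is a minimal NumPy sketch of that computation, not part of this PR: the names `kl_integrand_diagonal` and `kl_integrand_full` and the `eps` parameter are illustrative assumptions, and plain NumPy is used instead of JAX only to keep the example dependency-light. It checks that the two code paths agree when the full diffusion matrix happens to be diagonal, including for a negative diagonal entry.

```python
import numpy as np


def kl_integrand_diagonal(drift_diff, diffusion, eps=1e-7):
    """0.5 * ||drift_diff / diffusion||^2, clamping small-magnitude entries."""
    # Preserve the sign of each entry; exact zeros are treated as positive.
    sign = np.where(diffusion < 0, -1.0, 1.0)
    safe = np.where(np.abs(diffusion) > eps, diffusion, eps * sign)
    return 0.5 * np.sum((drift_diff / safe) ** 2)


def kl_integrand_full(drift_diff, diffusion):
    """0.5 * ||pinv(diffusion) @ drift_diff||^2 for a (d, m) diffusion matrix."""
    scale = np.linalg.pinv(diffusion) @ drift_diff
    return 0.5 * np.sum(scale**2)


# Difference of the two drifts (drift1 - drift2 in the diff above).
drift_diff = np.array([1.0, -2.0, 0.5])
# Diagonal diffusion with one negative entry.
sigma = np.array([0.5, -2.0, 1.0])

diag_value = kl_integrand_diagonal(drift_diff, sigma)
full_value = kl_integrand_full(drift_diff, np.diag(sigma))
print(diag_value, full_value)
```

In the augmented SDE built by `sde_kl_divergence`, this per-time value is what the second component of the state accumulates through the drift term, starting from `0.0`.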