Skip to content

Commit f9e40e1

Browse files
committed
3.9 fix2
1 parent 35af3e9 commit f9e40e1

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

diffrax/_solver/kl.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@
3333
)
3434

3535

36-
DiffusionTerm = Union[ControlTerm, WeaklyDiagonalControlTerm]
36+
_DiffusionTerm = Union[ControlTerm, WeaklyDiagonalControlTerm]
3737

3838

3939
def _compute_kl_integral(
4040
drift_term1: ODETerm,
4141
drift_term2: ODETerm,
42-
diffusion_term: DiffusionTerm,
42+
diffusion_term: _DiffusionTerm,
4343
t0: RealScalarLike,
4444
y0: Y,
4545
args: Args,
@@ -95,7 +95,7 @@ def _compute_kl_integral(
9595
class _KLDrift(AbstractTerm):
9696
drift1: ODETerm
9797
drift2: ODETerm
98-
diffusion: DiffusionTerm
98+
diffusion: _DiffusionTerm
9999
linear_solver: lx.AbstractLinearSolver
100100

101101
def vf(self, t: RealScalarLike, y: Y, args: Args) -> Tuple[VF, RealScalarLike]:
@@ -112,7 +112,7 @@ def prod(self, vf: VF, control: RealScalarLike) -> Y:
112112

113113

114114
class _KLControlTerm(AbstractTerm):
115-
control_term: DiffusionTerm
115+
control_term: _DiffusionTerm
116116

117117
def vf(self, t: RealScalarLike, y: Y, args: Args) -> Tuple[VF, RealScalarLike]:
118118
y, _ = y
@@ -160,7 +160,7 @@ class KLSolver(AbstractWrappedSolver[_SolverState]):
160160
The input must be a `MultiTerm` composed of the first SDE with drift `f`
161161
and diffusion `g` and the second either a SDE or just the drift term
162162
(since the diffusion is assumed to be the same). For example, a type
163-
of: `MuliTerm(MultiTerm(ODETerm, DiffusionTerm), ODETerm)`.
163+
of: `MuliTerm(MultiTerm(ODETerm, _DiffusionTerm), ODETerm)`.
164164
165165
??? cite "References"
166166
@@ -260,12 +260,17 @@ def step(
260260
drift_term1, drift_term2 = drift_term1[0], drift_term2[0]
261261

262262
diffusion_term = jtu.tree_map(
263-
lambda x: x if isinstance(x, DiffusionTerm) else None,
263+
lambda x: x
264+
if isinstance(x, WeaklyDiagonalControlTerm) or isinstance(x, ControlTerm)
265+
else None,
264266
terms1,
265-
is_leaf=lambda x: isinstance(x, DiffusionTerm),
267+
is_leaf=lambda x: isinstance(x, WeaklyDiagonalControlTerm)
268+
or isinstance(x, ControlTerm),
266269
)
267270
diffusion_term = jtu.tree_leaves(
268-
diffusion_term, is_leaf=lambda x: isinstance(x, DiffusionTerm)
271+
diffusion_term,
272+
is_leaf=lambda x: isinstance(x, WeaklyDiagonalControlTerm)
273+
or isinstance(x, ControlTerm),
269274
)
270275

271276
diffusion_term = eqx.error_if(

0 commit comments

Comments
 (0)