3333)
3434
3535
36- DiffusionTerm = Union [ControlTerm , WeaklyDiagonalControlTerm ]
36+ _DiffusionTerm = Union [ControlTerm , WeaklyDiagonalControlTerm ]
3737
3838
3939def _compute_kl_integral (
4040 drift_term1 : ODETerm ,
4141 drift_term2 : ODETerm ,
42- diffusion_term : DiffusionTerm ,
42+ diffusion_term : _DiffusionTerm ,
4343 t0 : RealScalarLike ,
4444 y0 : Y ,
4545 args : Args ,
@@ -95,7 +95,7 @@ def _compute_kl_integral(
9595class _KLDrift (AbstractTerm ):
9696 drift1 : ODETerm
9797 drift2 : ODETerm
98- diffusion : DiffusionTerm
98+ diffusion : _DiffusionTerm
9999 linear_solver : lx .AbstractLinearSolver
100100
101101 def vf (self , t : RealScalarLike , y : Y , args : Args ) -> Tuple [VF , RealScalarLike ]:
@@ -112,7 +112,7 @@ def prod(self, vf: VF, control: RealScalarLike) -> Y:
112112
113113
114114class _KLControlTerm (AbstractTerm ):
115- control_term : DiffusionTerm
115+ control_term : _DiffusionTerm
116116
117117 def vf (self , t : RealScalarLike , y : Y , args : Args ) -> Tuple [VF , RealScalarLike ]:
118118 y , _ = y
@@ -160,7 +160,7 @@ class KLSolver(AbstractWrappedSolver[_SolverState]):
160160 The input must be a `MultiTerm` composed of the first SDE with drift `f`
161161 and diffusion `g` and the second either a SDE or just the drift term
162162 (since the diffusion is assumed to be the same). For example, a type
163- of: `MuliTerm(MultiTerm(ODETerm, DiffusionTerm ), ODETerm)`.
163+ of: `MuliTerm(MultiTerm(ODETerm, _DiffusionTerm ), ODETerm)`.
164164
165165 ??? cite "References"
166166
@@ -260,12 +260,17 @@ def step(
260260 drift_term1 , drift_term2 = drift_term1 [0 ], drift_term2 [0 ]
261261
262262 diffusion_term = jtu .tree_map (
263- lambda x : x if isinstance (x , DiffusionTerm ) else None ,
263+ lambda x : x
264+ if isinstance (x , WeaklyDiagonalControlTerm ) or isinstance (x , ControlTerm )
265+ else None ,
264266 terms1 ,
265- is_leaf = lambda x : isinstance (x , DiffusionTerm ),
267+ is_leaf = lambda x : isinstance (x , WeaklyDiagonalControlTerm )
268+ or isinstance (x , ControlTerm ),
266269 )
267270 diffusion_term = jtu .tree_leaves (
268- diffusion_term , is_leaf = lambda x : isinstance (x , DiffusionTerm )
271+ diffusion_term ,
272+ is_leaf = lambda x : isinstance (x , WeaklyDiagonalControlTerm )
273+ or isinstance (x , ControlTerm ),
269274 )
270275
271276 diffusion_term = eqx .error_if (
0 commit comments