1010import jax .random as jr
1111import jax .tree_util as jtu
1212import lineax .internal as lxi
13- from jaxtyping import Array , Float , PRNGKeyArray , PyTree
13+ from jaxtyping import Array , Inexact , PRNGKeyArray , PyTree
14+ from lineax .internal import complex_to_real_dtype
1415
1516from .._custom_types import (
1617 AbstractBrownianIncrement ,
5455# For the midpoint rule for generating space-time Levy area see Theorem 6.1.6.
5556# For the general interpolation rule for space-time Levy area see Theorem 6.1.4.
5657
57- FloatDouble : TypeAlias = tuple [Float [Array , " *shape" ], Float [Array , " *shape" ]]
58+ FloatDouble : TypeAlias = tuple [Inexact [Array , " *shape" ], Inexact [Array , " *shape" ]]
5859FloatTriple : TypeAlias = tuple [
59- Float [Array , " *shape" ], Float [Array , " *shape" ], Float [Array , " *shape" ]
60+ Inexact [Array , " *shape" ], Inexact [Array , " *shape" ], Inexact [Array , " *shape" ]
6061]
6162_Spline : TypeAlias = Literal ["sqrt" , "quad" , "zero" ]
6263_BrownianReturn = TypeVar ("_BrownianReturn" , bound = AbstractBrownianIncrement )
@@ -90,7 +91,7 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
9091 assert len (x1 ) == 2
9192 dt0 , w0 = x0
9293 dt1 , w1 = x1
93- su = jnp .asarray (dt1 - dt0 , dtype = w0 .dtype )
94+ su = jnp .asarray (dt1 - dt0 , dtype = complex_to_real_dtype ( w0 .dtype ) )
9495 return BrownianIncrement (dt = su , W = w1 - w0 )
9596
9697 elif len (x0 ) == 4 : # space-time levy area case
@@ -99,12 +100,13 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
99100 dt1 , w1 , hh1 , bhh1 = x1
100101
101102 w_su = w1 - w0
102- su = jnp .asarray (dt1 - dt0 , dtype = w0 .dtype )
103+ su = jnp .asarray (dt1 - dt0 , dtype = complex_to_real_dtype ( w0 .dtype ) )
103104 _su = jnp .where (jnp .abs (su ) < jnp .finfo (su ).eps , jnp .inf , su )
104105 inverse_su = 1 / _su
105- u_bb_s = dt1 * w0 - dt0 * w1
106- bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
107- hh_su = inverse_su * bhh_su
106+ with jax .numpy_dtype_promotion ("standard" ):
107+ u_bb_s = dt1 * w0 - dt0 * w1
108+ bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
109+ hh_su = inverse_su * bhh_su
108110 return SpaceTimeLevyArea (dt = su , W = w_su , H = hh_su )
109111 else :
110112 assert False
@@ -135,10 +137,19 @@ def _split_interval(
135137class VirtualBrownianTree (AbstractBrownianPath ):
136138 """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
137139
138- Can be initialised with `levy_area` set to `""`, or `"space-time"`.
139- If `levy_area="space_time"`, then it also computes space-time Lévy area `H`.
140- This will impact the Brownian path, so even with the same key, the trajectory will
141- be different depending on the value of `levy_area`.
140+ !!! info "Levy Area"
141+
142+ Can be initialised with `levy_area` set to `diffrax.BrownianIncrement`, or
143+ `diffrax.SpaceTimeLevyArea`. If `levy_area=diffrax.SpaceTimeLevyArea`, then it
144+ also computes space-time Lévy area `H`. This is an additional source of
145+ randomness required for certain stochastic Runge--Kutta solvers; see
146+ [`diffrax.AbstractSRK`][] for more information.
147+
148+ An error will be thrown during tracing if Lévy area is required but is not
149+ available.
150+
151+ The choice here will impact the Brownian path, so even with the same key, the
152+ trajectory will be different depending on the value of `levy_area`.
142153
143154 ??? cite "Reference"
144155
@@ -283,9 +294,10 @@ def _evaluate_leaf(
283294 tuple [RealScalarLike , Array ], tuple [RealScalarLike , Array , Array , Array ]
284295 ]:
285296 shape , dtype = struct .shape , struct .dtype
297+ tdtype = complex_to_real_dtype (dtype )
286298
287- t0 = jnp .zeros ((), dtype )
288- r = jnp .asarray (r , dtype )
299+ t0 = jnp .zeros ((), tdtype )
300+ r = jnp .asarray (r , tdtype )
289301
290302 if self .levy_area is SpaceTimeLevyArea :
291303 state_key , init_key_w , init_key_la = jr .split (key , 3 )
@@ -394,14 +406,33 @@ def _body_fun(_state: _State):
394406 a = d_prime * sr3 * sr_ru_half
395407 b = d_prime * ru3 * sr_ru_half
396408
397- w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
398- w_r = w_s + w_sr
399- c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
400- bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
401- bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
409+ with jax .numpy_dtype_promotion ("standard" ):
410+ w_sr = (
411+ sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
412+ )
413+ w_r = w_s + w_sr
414+ c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
415+ bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
416+ bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
402417
403- inverse_r = 1 / jnp .where (jnp .abs (r ) < jnp .finfo (r ).eps , jnp .inf , r )
404- hh_r = inverse_r * bhh_r
418+ inverse_r = 1 / jnp .where (jnp .abs (r ) < jnp .finfo (r ).eps , jnp .inf , r )
419+ hh_r = inverse_r * bhh_r
420+
421+ elif self .levy_area is BrownianIncrement :
422+ with jax .numpy_dtype_promotion ("standard" ):
423+ w_mean = w_s + sr / su * w_su
424+ if self ._spline == "sqrt" :
425+ z = jr .normal (final_state .key , shape , dtype )
426+ bb = jnp .sqrt (sr * ru / su ) * z
427+ elif self ._spline == "quad" :
428+ z = jr .normal (final_state .key , shape , dtype )
429+ bb = (sr * ru / su ) * z
430+ elif self ._spline == "zero" :
431+ bb = jnp .zeros (shape , dtype )
432+ else :
433+ assert False
434+ w_r = w_mean + bb
435+ return r , w_r
405436
406437 elif self .levy_area is BrownianIncrement :
407438 w_mean = w_s + sr / su * w_su
@@ -497,8 +528,8 @@ def _brownian_arch(
497528
498529 w_t = w_s + w_st
499530 w_stu = (w_s , w_t , w_u )
500-
501- bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t )
531+ with jax . numpy_dtype_promotion ( "standard" ):
532+ bhh_t = bhh_s + bhh_st + 0.5 * (t * w_s - s * w_t )
502533 bhh_stu = (bhh_s , bhh_t , bhh_u )
503534 bkk_stu = None
504535 bkk_st_tu = None
0 commit comments