@@ -302,50 +302,39 @@ def euler_maruyama_step(
302302 state : dict [str , ArrayLike ],
303303 time : ArrayLike ,
304304 step_size : ArrayLike ,
305- seed : keras . random . SeedGenerator ,
305+ noise : dict [ str , ArrayLike ] ,
306306) -> (dict [str , ArrayLike ], ArrayLike , ArrayLike ):
307307 """
308308 Performs a single Euler-Maruyama step for stochastic differential equations.
309309
310310 Args:
311- drift_fn: Function that computes the drift term.
312- diffusion_fn: Function that computes the diffusion term.
313- state: Dictionary containing the current state .
314- time: Current time.
315- step_size: Size of the integration step .
316- seed: Random seed for noise generation .
311+ drift_fn: Function computing the drift term f(t, **state) .
312+ diffusion_fn: Function computing the diffusion term g(t, **state) .
313+ state: Current state, mapping variable names to tensors .
314+ time: Current time scalar tensor .
315+ step_size: Time increment dt .
316+ noise: Mapping of variable names to dW noise tensors .
317317
318318 Returns:
319- Tuple of (new_state, new_time, new_step_size).
319+ new_state: Updated state after one Euler-Maruyama step.
320+ new_time: time + dt.
320321 """
321- # Compute drift term
322+ # Compute drift and diffusion
322323 drift = drift_fn (time , ** filter_kwargs (state , drift_fn ))
323-
324- # Compute diffusion term
325324 diffusion = diffusion_fn (time , ** filter_kwargs (state , diffusion_fn ))
326325
327- # Generate noise for this step
328- noise = {}
329- for key in state .keys ():
330- eps = keras .random .normal (keras .ops .shape (state [key ]), dtype = keras .ops .dtype (state [key ]), seed = seed )
331- noise [key ] = eps * keras .ops .sqrt (keras .ops .abs (step_size ))
332-
333- # Check if diffusion and noise have the same keys
326+ # Check noise keys
334327 if set (diffusion .keys ()) != set (noise .keys ()):
335328 raise ValueError ("Keys of diffusion terms and noise do not match." )
336329
337- # Apply updates using Euler-Maruyama formula: dx = f(x)dt + g(x)dW
338- new_state = state .copy ()
339- for key in drift .keys ():
340- if key in diffusion :
341- new_state [key ] = state [key ] + (step_size * drift [key ]) + (diffusion [key ] * noise [key ])
342- else :
343- # If no diffusion term for this variable, apply deterministic update
344- new_state [key ] = state [key ] + step_size * drift [key ]
330+ new_state = {}
331+ for key , d in drift .items ():
332+ base = state [key ] + step_size * d
333+ if key in diffusion : # stochastic update
334+ base = base + diffusion [key ] * noise [key ]
335+ new_state [key ] = base
345336
346- new_time = time + step_size
347-
348- return new_state , new_time
337+ return new_state , time + step_size
349338
350339
351340def integrate_stochastic (
@@ -356,7 +345,7 @@ def integrate_stochastic(
356345 stop_time : ArrayLike ,
357346 steps : int ,
358347 seed : keras .random .SeedGenerator ,
359- method : Literal [ "euler_maruyama" ] = "euler_maruyama" ,
348+ method : str = "euler_maruyama" ,
360349 ** kwargs ,
361350) -> Union [dict [str , ArrayLike ], tuple [dict [str , ArrayLike ], dict [str , Sequence [ArrayLike ]]]]:
362351 """
@@ -370,7 +359,7 @@ def integrate_stochastic(
370359 stop_time: Ending time for integration.
371360 steps: Number of integration steps.
372361 seed: Random seed for noise generation.
373- method: Integration method to use ( 'euler_maruyama') .
362+ method: Integration method to use, e.g., 'euler_maruyama'.
374363 **kwargs: Additional arguments to pass to the step function.
375364
376365 Returns:
@@ -390,18 +379,22 @@ def integrate_stochastic(
390379 # Prepare step function with partial application
391380 step_fn = partial (step_fn , drift_fn = drift_fn , diffusion_fn = diffusion_fn , ** kwargs )
392381
382+ # Time step
393383 step_size = (stop_time - start_time ) / steps
394- time = start_time
395- current_state = state .copy ()
396-
397- # keras.ops.fori_loop does not support keras seed generator in jax
398- for i in range (steps ):
399- # Execute the step with the specific seed for this step
400- current_state , time = step_fn (
401- state = current_state ,
402- time = time ,
403- step_size = step_size ,
404- seed = seed ,
384+ sqrt_dt = keras .ops .sqrt (keras .ops .abs (step_size ))
385+
386+ # Pre-generate noise history: shape = (steps, *state_shape)
387+ noise_history = {}
388+ for key , val in state .items ():
389+ noise_history [key ] = (
390+ keras .random .normal ((steps , * keras .ops .shape (val )), dtype = keras .ops .dtype (val ), seed = seed ) * sqrt_dt
405391 )
406392
407- return current_state
393+ def body (_loop_var , _loop_state ):
394+ _current_state , _current_time = _loop_state
395+ _noise_i = {k : noise_history [k ][_loop_var ] for k in _current_state .keys ()}
396+ new_state , new_time = step_fn (state = _current_state , time = _current_time , step_size = step_size , noise = _noise_i )
397+ return new_state , new_time
398+
399+ final_state , final_time = keras .ops .fori_loop (0 , steps , body , (state , start_time ))
400+ return final_state
0 commit comments