@@ -301,7 +301,7 @@ def euler_maruyama_step(
301301 state : dict [str , ArrayLike ],
302302 time : ArrayLike ,
303303 step_size : ArrayLike ,
304- noise : dict [str , ArrayLike ] = None ,
304+ noise : dict [str , ArrayLike ],
305305 tolerance : ArrayLike = 1e-6 ,
306306 min_step_size : ArrayLike = - float ("inf" ),
307307 max_step_size : ArrayLike = float ("inf" ),
@@ -331,13 +331,6 @@ def euler_maruyama_step(
331331 # Compute diffusion term
332332 diffusion = diffusion_fn (time , ** filter_kwargs (state , diffusion_fn ))
333333
334- # Generate noise if not provided
335- if noise is None :
336- noise = {}
337- for key in diffusion .keys ():
338- shape = keras .ops .shape (diffusion [key ])
339- noise [key ] = keras .random .normal (shape ) * keras .ops .sqrt (keras .ops .abs (step_size ))
340-
341334 # Check if diffusion and noise have the same keys
342335 if set (diffusion .keys ()) != set (noise .keys ()):
343336 raise ValueError ("Keys of diffusion terms and noise do not match." )
@@ -414,10 +407,6 @@ def integrate_stochastic(
414407 if steps <= 0 :
415408 raise ValueError ("Number of steps must be positive." )
416409
417- # Set random seed if provided
418- if seed is not None :
419- keras .random .set_seed (seed )
420-
421410 # Select step function based on method
422411 match method :
423412 case "euler_maruyama" :
@@ -440,7 +429,7 @@ def body(_loop_var, _loop_state):
440429 _noise = {}
441430 for key in _state .keys ():
442431 shape = keras .ops .shape (_state [key ])
443- _noise [key ] = keras .random .normal (shape ) * keras .ops .sqrt (keras .ops .abs (step_size ))
432+ _noise [key ] = keras .random .normal (shape , seed = seed ) * keras .ops .sqrt (keras .ops .abs (step_size ))
444433
445434 # Perform integration step
446435 _state , _time , _ = step_fn (_state , _time , step_size , noise = _noise )
0 commit comments