@@ -336,7 +336,7 @@ def euler_maruyama_step(
336336 noise = {}
337337 for key in diffusion .keys ():
338338 shape = keras .ops .shape (diffusion [key ])
339- noise [key ] = keras .random .normal (shape ) * keras .ops .sqrt (step_size )
339+ noise [key ] = keras .random .normal (shape ) * keras .ops .sqrt (keras . ops . abs ( step_size ) )
340340
341341 # Check if diffusion and noise have the same keys
342342 if set (diffusion .keys ()) != set (noise .keys ()):
@@ -391,7 +391,6 @@ def integrate_stochastic(
391391 steps : int ,
392392 method : str = "euler_maruyama" ,
393393 seed : int = None ,
394- return_noise : bool = False ,
395394 ** kwargs ,
396395) -> Union [dict [str , ArrayLike ], tuple [dict [str , ArrayLike ], dict [str , List [ArrayLike ]]]]:
397396 """
@@ -406,7 +405,6 @@ def integrate_stochastic(
406405 steps: Number of integration steps.
407406 method: Integration method to use ('euler_maruyama').
408407 seed: Random seed for noise generation.
409- return_noise: Whether to return the generated noise terms.
410408 **kwargs: Additional arguments to pass to the step function.
411409
412410 Returns:
@@ -435,31 +433,19 @@ def integrate_stochastic(
435433
436434 time = start_time
437435
438- # Store noise history if requested
439- noise_history = {key : [] for key in state .keys ()} if return_noise else None
440-
441436 def body (_loop_var , _loop_state ):
442437 _state , _time = _loop_state
443438
444439 # Generate noise for this step
445440 _noise = {}
446441 for key in _state .keys ():
447442 shape = keras .ops .shape (_state [key ])
448- _noise [key ] = keras .random .normal (shape ) * keras .ops .sqrt (step_size )
449-
450- # Store noise if requested
451- if return_noise :
452- for key in _noise :
453- noise_history [key ].append (_noise [key ])
443+ _noise [key ] = keras .random .normal (shape ) * keras .ops .sqrt (keras .ops .abs (step_size ))
454444
455445 # Perform integration step
456446 _state , _time , _ = step_fn (_state , _time , step_size , noise = _noise )
457447
458448 return _state , _time
459449
460450 state , time = keras .ops .fori_loop (0 , steps , body , (state , time ))
461-
462- if return_noise :
463- return state , noise_history
464- else :
465- return state
451+ return state
0 commit comments