@@ -401,11 +401,17 @@ def integrate_stochastic(
401401 steps : int ,
402402 seed : keras .random .SeedGenerator ,
403403 method : str = "euler_maruyama" ,
404+ score_fn : Callable = None ,
405+ corrector_steps : int = 0 ,
404406 ** kwargs ,
405407) -> Union [dict [str , ArrayLike ], tuple [dict [str , ArrayLike ], dict [str , Sequence [ArrayLike ]]]]:
406408 """
407409 Integrates a stochastic differential equation from start_time to stop_time.
408410
411+ When score_fn is provided, performs predictor-corrector sampling where:
412+ - Predictor: reverse diffusion SDE solver
413+ - Corrector: annealed Langevin dynamics with step size e = sqrt(dim)
414+
409415 Args:
410416 drift_fn: Function that computes the drift term.
411417 diffusion_fn: Function that computes the diffusion term.
@@ -415,11 +421,13 @@ def integrate_stochastic(
415421 steps: Number of integration steps.
416422 seed: Random seed for noise generation.
417423 method: Integration method to use, e.g., 'euler_maruyama'.
424+ score_fn: Optional score function for predictor-corrector sampling.
425+ Should take (time, **state) and return score dict.
426+ corrector_steps: Number of corrector steps to take after each predictor step.
418427 **kwargs: Additional arguments to pass to the step function.
419428
420429 Returns:
421- If return_noise is False, returns the final state dictionary.
422- If return_noise is True, returns a tuple of (final_state, noise_history).
430+ Final state dictionary after integration.
423431 """
424432 if steps <= 0 :
425433 raise ValueError ("Number of steps must be positive." )
@@ -438,17 +446,44 @@ def integrate_stochastic(
438446 step_size = (stop_time - start_time ) / steps
439447 sqrt_dt = keras .ops .sqrt (keras .ops .abs (step_size ))
440448
441- # Pre-generate noise history: shape = (steps, *state_shape)
449+ # Pre-generate noise history for predictor : shape = (steps, *state_shape)
442450 noise_history = {}
443451 for key , val in state .items ():
444452 noise_history [key ] = (
445453 keras .random .normal ((steps , * keras .ops .shape (val )), dtype = keras .ops .dtype (val ), seed = seed ) * sqrt_dt
446454 )
447455
456+ # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape)
457+ corrector_noise_history = {}
458+ if score_fn is not None and corrector_steps > 0 :
459+ for key , val in state .items ():
460+ corrector_noise_history [key ] = keras .random .normal (
461+ (steps , corrector_steps , * keras .ops .shape (val )), dtype = keras .ops .dtype (val ), seed = seed
462+ )
463+
448464 def body (_loop_var , _loop_state ):
449465 _current_state , _current_time = _loop_state
450466 _noise_i = {k : noise_history [k ][_loop_var ] for k in _current_state .keys ()}
467+
468+ # Predictor step
451469 new_state , new_time = step_fn (state = _current_state , time = _current_time , step_size = step_size , noise = _noise_i )
470+
471+ # Corrector steps: annealed Langevin dynamics if score_fn is provided
472+ if score_fn is not None :
473+ first_key = next (iter (new_state .keys ()))
474+ dim = keras .ops .cast (keras .ops .shape (new_state [first_key ])[- 1 ], keras .ops .dtype (new_state [first_key ]))
475+ e = keras .ops .sqrt (dim )
476+ sqrt_2e = keras .ops .sqrt (2.0 * e )
477+
478+ for corrector_step in range (corrector_steps ):
479+ score = score_fn (new_time , ** filter_kwargs (new_state , score_fn ))
480+ _corrector_noise = {k : corrector_noise_history [k ][_loop_var , corrector_step ] for k in new_state .keys ()}
481+
482+ # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
483+ for k in new_state .keys ():
484+ if k in score :
485+ new_state [k ] = new_state [k ] + e * score [k ] + sqrt_2e * _corrector_noise [k ]
486+
452487 return new_state , new_time
453488
454489 final_state , final_time = keras .ops .fori_loop (0 , steps , body , (state , start_time ))
0 commit comments