@@ -403,6 +403,7 @@ def integrate_stochastic(
403403 method : str = "euler_maruyama" ,
404404 score_fn : Callable = None ,
405405 corrector_steps : int = 0 ,
406+ noise_schedule = None ,
406407 ** kwargs ,
407408) -> Union [dict [str , ArrayLike ], tuple [dict [str , ArrayLike ], dict [str , Sequence [ArrayLike ]]]]:
408409 """
@@ -424,6 +425,7 @@ def integrate_stochastic(
424425 score_fn: Optional score function for predictor-corrector sampling.
425426 Should take (time, **state) and return score dict.
426427 corrector_steps: Number of corrector steps to take after each predictor step.
428+ noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
427429 **kwargs: Additional arguments to pass to the step function.
428430
429431 Returns:
@@ -455,7 +457,10 @@ def integrate_stochastic(
455457
456458 # Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape)
457459 corrector_noise_history = {}
458- if score_fn is not None and corrector_steps > 0 :
460+ if corrector_steps > 0 :
461+ if score_fn is None or noise_schedule is None :
462+ raise ValueError ("Please provide both score_fn and noise_schedule when using corrector_steps > 0." )
463+
459464 for key , val in state .items ():
460465 corrector_noise_history [key ] = keras .random .normal (
461466 (steps , corrector_steps , * keras .ops .shape (val )), dtype = keras .ops .dtype (val ), seed = seed
@@ -469,19 +474,29 @@ def body(_loop_var, _loop_state):
469474 new_state , new_time = step_fn (state = _current_state , time = _current_time , step_size = step_size , noise = _noise_i )
470475
471476 # Corrector steps: annealed Langevin dynamics if score_fn is provided
472- if score_fn is not None :
473- first_key = next (iter (new_state .keys ()))
474- dim = keras .ops .cast (keras .ops .shape (new_state [first_key ])[- 1 ], keras .ops .dtype (new_state [first_key ]))
475- e = keras .ops .sqrt (dim )
476- sqrt_2e = keras .ops .sqrt (2.0 * e )
477-
477+ if corrector_steps > 0 :
478478 for corrector_step in range (corrector_steps ):
479479 score = score_fn (new_time , ** filter_kwargs (new_state , score_fn ))
480480 _corrector_noise = {k : corrector_noise_history [k ][_loop_var , corrector_step ] for k in new_state .keys ()}
481481
482+ # Compute noise schedule components for corrector step size
483+ log_snr_t = noise_schedule .get_log_snr (t = new_time , training = False )
484+ alpha_t , _ = noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
485+ lambda_t = keras .ops .exp (- log_snr_t ) # lambda_t = exp(-log SNR_t)
486+
482487 # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
488+ # where e = 2*alpha_t * (lambda_t * ||x|| / ||score||)**2, with ||x|| the norm of the current state
488+ # NOTE(review): Song et al.'s PC sampler uses the *noise* norm here, not the state norm — confirm intended
483489 for k in new_state .keys ():
484490 if k in score :
491+ z_norm = keras .ops .norm (new_state [k ], axis = - 1 , keepdims = True )
492+ score_norm = keras .ops .norm (score [k ], axis = - 1 , keepdims = True )
493+
494+ # Prevent division by zero
495+ score_norm = keras .ops .maximum (score_norm , 1e-8 )
496+
497+ e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm ) ** 2
498+ sqrt_2e = keras .ops .sqrt (2.0 * e )
499+
485500 new_state [k ] = new_state [k ] + e * score [k ] + sqrt_2e * _corrector_noise [k ]
486501
487502 return new_state , new_time
0 commit comments