@@ -404,6 +404,7 @@ def integrate_stochastic(
404404 score_fn : Callable = None ,
405405 corrector_steps : int = 0 ,
406406 noise_schedule = None ,
407+ r : float = 0.1 ,
407408 ** kwargs ,
408409) -> Union [dict [str , ArrayLike ], tuple [dict [str , ArrayLike ], dict [str , Sequence [ArrayLike ]]]]:
409410 """
@@ -426,6 +427,7 @@ def integrate_stochastic(
426427 Should take (time, **state) and return score dict.
427428 corrector_steps: Number of corrector steps to take after each predictor step.
428429 noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
430+ r: Scaling factor for corrector step size.
429431 **kwargs: Additional arguments to pass to the step function.
430432
431433 Returns:
@@ -482,10 +484,9 @@ def body(_loop_var, _loop_state):
482484 # Compute noise schedule components for corrector step size
483485 log_snr_t = noise_schedule .get_log_snr (t = new_time , training = False )
484486 alpha_t , _ = noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
485- lambda_t = keras .ops .exp (- log_snr_t ) # lambda_t from noise schedule
486487
487488 # Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
488- # where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2
489+ # where e = 2*alpha_t * (r * ||z|| / ||score||)**2
489490 for k in new_state .keys ():
490491 if k in score :
491492 z_norm = keras .ops .norm (new_state [k ], axis = - 1 , keepdims = True )
@@ -494,7 +495,7 @@ def body(_loop_var, _loop_state):
494495 # Prevent division by zero
495496 score_norm = keras .ops .maximum (score_norm , 1e-8 )
496497
497- e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm ) ** 2
498+ e = 2.0 * alpha_t * (r * z_norm / score_norm ) ** 2
498499 sqrt_2e = keras .ops .sqrt (2.0 * e )
499500
500501 new_state [k ] = new_state [k ] + e * score [k ] + sqrt_2e * _corrector_noise [k ]
0 commit comments