@@ -859,6 +859,7 @@ def scores(time, xz):
859859
860860 state = annealed_langevin (
861861 score_fn = scores ,
862+ noise_schedule = self .noise_schedule ,
862863 state = state ,
863864 seed = self .seed_generator ,
864865 ** filter_kwargs (integrate_kwargs , annealed_langevin ),
@@ -886,13 +887,14 @@ def deltas(time, xz):
886887
887888def annealed_langevin (
888889 score_fn : Callable ,
890+ noise_schedule : Callable ,
889891 state : dict [str , ArrayLike ],
890892 steps : int ,
891893 seed : keras .random .SeedGenerator ,
892- L : int = 5 ,
893894 start_time : ArrayLike = None ,
894895 stop_time : ArrayLike = None ,
895- eps : float = 0.01 ,
896+ langevin_corrector_steps : int = 5 ,
897+ step_size_factor : float = 0.1 ,
896898) -> dict [str , ArrayLike ]:
897899 """
898900 Annealed Langevin dynamics for diffusion sampling.
@@ -902,30 +904,25 @@ def annealed_langevin(
902904 eta ~ N(0, I)
903905 theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta
904906 """
905- ratio = keras .ops .convert_to_tensor (
906- (stop_time + eps ) / start_time , dtype = keras .ops .dtype (next (iter (state .values ())))
907- )
907+ log_snr_t = noise_schedule .get_log_snr (t = start_time , training = False )
908+ _ , max_sigma_t = noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
908909
909- T = steps
910910 # main loops
911- for t_T in range (T - 1 , 0 , - 1 ):
912- t = t_T / T
913- dt = keras .ops .convert_to_tensor (stop_time , dtype = keras .ops .dtype (next (iter (state .values ())))) * (
914- ratio ** (stop_time - t )
915- )
916-
917- sqrt_dt = keras .ops .sqrt (keras .ops .abs (dt ))
918- # inner L Langevin steps at level t
919- for _ in range (L ):
920- # score
911+ for step in range (steps - 1 , 0 , - 1 ):
912+ t = step / steps
913+ log_snr_t = noise_schedule .get_log_snr (t = t , training = False )
914+ _ , sigma_t = noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
915+ annealing_step_size = step_size_factor * keras .ops .square (sigma_t / max_sigma_t )
916+
917+ sqrt_dt = keras .ops .sqrt (keras .ops .abs (annealing_step_size ))
918+ for _ in range (langevin_corrector_steps ):
921919 drift = score_fn (t , ** filter_kwargs (state , score_fn ))
922- # noise
923- eta = {
920+ noise = {
924921 k : keras .random .normal (keras .ops .shape (v ), dtype = keras .ops .dtype (v ), seed = seed )
925922 for k , v in state .items ()
926923 }
927924
928925 # update
929926 for k , d in drift .items ():
930- state [k ] = state [k ] + 0.5 * dt * d + sqrt_dt * eta [k ]
927+ state [k ] = state [k ] + 0.5 * annealing_step_size * d + sqrt_dt * noise [k ]
931928 return state
0 commit comments