1616 integrate_stochastic ,
1717 logging ,
1818 tensor_utils ,
19+ filter_kwargs ,
1920)
2021from bayesflow .utils .serialization import serialize , deserialize , serializable
2122
2223from .schedules .noise_schedule import NoiseSchedule
2324from .dispatch import find_noise_schedule
2425
26+ ArrayLike = int | float | Tensor
27+
2528
2629# disable module check, use potential module after moving from experimental
2730@serializable ("bayesflow.networks" , disable_module_check = True )
@@ -840,6 +843,26 @@ def diffusion(time, xz):
840843 seed = self .seed_generator ,
841844 ** integrate_kwargs ,
842845 )
846+ elif integrate_kwargs ["method" ] == "langevin" :
847+
848+ def scores (time , xz ):
849+ return {
850+ "xz" : self .compositional_score (
851+ xz ,
852+ time = time ,
853+ conditions = conditions ,
854+ compute_prior_score = compute_prior_score ,
855+ mini_batch_size = mini_batch_size ,
856+ training = training ,
857+ )
858+ }
859+
860+ state = annealed_langevin (
861+ score_fn = scores ,
862+ state = state ,
863+ seed = self .seed_generator ,
864+ ** filter_kwargs (integrate_kwargs , annealed_langevin ),
865+ )
843866 else :
844867
845868 def deltas (time , xz ):
@@ -859,3 +882,50 @@ def deltas(time, xz):
859882
860883 x = state ["xz" ]
861884 return x
885+
886+
887+ def annealed_langevin (
888+ score_fn : Callable ,
889+ state : dict [str , ArrayLike ],
890+ steps : int ,
891+ seed : keras .random .SeedGenerator ,
892+ L : int = 5 ,
893+ start_time : ArrayLike = None ,
894+ stop_time : ArrayLike = None ,
895+ eps : float = 0.01 ,
896+ ) -> dict [str , ArrayLike ]:
897+ """
898+ Annealed Langevin dynamics for diffusion sampling.
899+
900+ for t = T-1,...,1:
901+ for s = 1,...,L:
902+ eta ~ N(0, I)
903+ theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta
904+ """
905+ ratio = keras .ops .convert_to_tensor (
906+ (stop_time + eps ) / start_time , dtype = keras .ops .dtype (next (iter (state .values ())))
907+ )
908+
909+ T = steps
910+ # main loops
911+ for t_T in range (T - 1 , 0 , - 1 ):
912+ t = t_T / T
913+ dt = keras .ops .convert_to_tensor (stop_time , dtype = keras .ops .dtype (next (iter (state .values ())))) * (
914+ ratio ** (stop_time - t )
915+ )
916+
917+ sqrt_dt = keras .ops .sqrt (keras .ops .abs (dt ))
918+ # inner L Langevin steps at level t
919+ for _ in range (L ):
920+ # score
921+ drift = score_fn (t , ** filter_kwargs (state , score_fn ))
922+ # noise
923+ eta = {
924+ k : keras .random .normal (keras .ops .shape (v ), dtype = keras .ops .dtype (v ), seed = seed )
925+ for k , v in state .items ()
926+ }
927+
928+ # update
929+ for k , d in drift .items ():
930+ state [k ] = state [k ] + 0.5 * dt * d + sqrt_dt * eta [k ]
931+ return state
0 commit comments