1616 integrate_stochastic ,
1717 logging ,
1818 tensor_utils ,
19- filter_kwargs ,
2019)
2120from bayesflow .utils .serialization import serialize , deserialize , serializable
2221
2322from .schedules .noise_schedule import NoiseSchedule
2423from .dispatch import find_noise_schedule
2524
# Scalar-or-tensor alias used by the sampling utilities below; `Tensor` is the
# backend tensor type imported elsewhere in this module — TODO confirm its origin.
ArrayLike = int | float | Tensor
27-
2825
2926# disable module check, use potential module after moving from experimental
3027@serializable ("bayesflow.networks" , disable_module_check = True )
@@ -917,27 +914,6 @@ def score_fn(time, xz):
917914 seed = self .seed_generator ,
918915 ** integrate_kwargs ,
919916 )
920- elif integrate_kwargs ["method" ] == "langevin" :
921-
922- def scores (time , xz ):
923- return {
924- "xz" : self .compositional_score (
925- xz ,
926- time = time ,
927- conditions = conditions ,
928- compute_prior_score = compute_prior_score ,
929- mini_batch_size = mini_batch_size ,
930- training = training ,
931- )
932- }
933-
934- state = annealed_langevin (
935- score_fn = scores ,
936- noise_schedule = self .noise_schedule ,
937- state = state ,
938- seed = self .seed_generator ,
939- ** filter_kwargs (integrate_kwargs , annealed_langevin ),
940- )
941917 else :
942918
943919 def deltas (time , xz ):
@@ -957,46 +933,3 @@ def deltas(time, xz):
957933
958934 x = state ["xz" ]
959935 return x
960-
961-
def annealed_langevin(
    score_fn: Callable,
    noise_schedule: Callable,
    state: dict[str, ArrayLike],
    steps: int,
    seed: keras.random.SeedGenerator,
    start_time: ArrayLike | None = None,
    stop_time: ArrayLike | None = None,
    langevin_corrector_steps: int = 5,
    step_size_factor: float = 0.1,
) -> dict[str, ArrayLike]:
    """
    Annealed Langevin dynamics for diffusion sampling.

    Runs the corrector-only sampler

        for t = (steps-1)/steps, ..., 1/steps:
            for s = 1, ..., langevin_corrector_steps:
                eta ~ N(0, I)
                theta <- theta + (eps_t / 2) * score(theta, t) + sqrt(eps_t) * eta

    where the per-level step size eps_t = step_size_factor * (sigma_t / sigma_max)^2
    anneals with the noise schedule, as in Song & Ermon's annealed Langevin sampler.

    Parameters
    ----------
    score_fn : Callable
        Called as ``score_fn(t, **kwargs)`` with the entries of ``state`` that match
        its signature; must return a dict mapping state keys to score/drift tensors.
    noise_schedule : Callable
        Object exposing ``get_log_snr(t, training)`` and ``get_alpha_sigma(log_snr_t)``
        (see ``NoiseSchedule``) — used only to derive sigma_t at each time level.
    state : dict[str, ArrayLike]
        Sampling state, updated and returned. Keys not present in ``score_fn``'s
        output receive no drift but DO receive noise only via matching drift keys
        (noise is drawn for every key, but applied only where a drift exists).
    steps : int
        Number of annealing levels; time runs over t = k/steps for k = steps-1 .. 1.
    seed : keras.random.SeedGenerator
        Source of randomness for the injected Gaussian noise.
    start_time : ArrayLike | None
        Time at which the maximum sigma is evaluated (normalizes the step size).
        NOTE(review): the original annotated this ``ArrayLike`` with default ``None``;
        ``None`` is forwarded to ``noise_schedule.get_log_snr`` — confirm the schedule
        accepts it, otherwise callers must pass an explicit value (typically 1.0).
    stop_time : ArrayLike | None
        Accepted for interface compatibility but currently UNUSED: the loop always
        stops at t = 1/steps. TODO(review): either honor it or remove it upstream.
    langevin_corrector_steps : int
        Number of Langevin corrector iterations per annealing level.
    step_size_factor : float
        Scale of the annealed step size relative to (sigma_t / sigma_max)^2.

    Returns
    -------
    dict[str, ArrayLike]
        The (mutated) ``state`` dict after all annealing levels.
    """
    # Reference noise level used to normalize the annealed step size.
    log_snr_max = noise_schedule.get_log_snr(t=start_time, training=False)
    _, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_max)

    # Outer loop: anneal from high noise (t close to 1) down to t = 1/steps.
    for step in range(steps - 1, 0, -1):
        t = step / steps
        log_snr_t = noise_schedule.get_log_snr(t=t, training=False)
        _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
        # eps_t = factor * (sigma_t / sigma_max)^2 — standard annealed-Langevin scaling.
        annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t)

        # abs() guards sqrt against a (pathological) negative step_size_factor.
        sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size))

        # Inner loop: Langevin corrector iterations at this fixed noise level.
        for _ in range(langevin_corrector_steps):
            drift = score_fn(t, **filter_kwargs(state, score_fn))
            noise = {
                k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed)
                for k, v in state.items()
            }

            # theta <- theta + (eps/2) * score + sqrt(eps) * eta, per returned drift key.
            for k, d in drift.items():
                state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k]
    return state
0 commit comments