Skip to content

Commit 0a87694

Browse files
committed
fix annealed_langevin
1 parent df23f89 commit 0a87694

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 16 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -859,6 +859,7 @@ def scores(time, xz):
859859

860860
state = annealed_langevin(
861861
score_fn=scores,
862+
noise_schedule=self.noise_schedule,
862863
state=state,
863864
seed=self.seed_generator,
864865
**filter_kwargs(integrate_kwargs, annealed_langevin),
@@ -886,13 +887,14 @@ def deltas(time, xz):
886887

887888
def annealed_langevin(
888889
score_fn: Callable,
890+
noise_schedule: Callable,
889891
state: dict[str, ArrayLike],
890892
steps: int,
891893
seed: keras.random.SeedGenerator,
892-
L: int = 5,
893894
start_time: ArrayLike = None,
894895
stop_time: ArrayLike = None,
895-
eps: float = 0.01,
896+
langevin_corrector_steps: int = 5,
897+
step_size_factor: float = 0.1,
896898
) -> dict[str, ArrayLike]:
897899
"""
898900
Annealed Langevin dynamics for diffusion sampling.
@@ -902,30 +904,25 @@ def annealed_langevin(
902904
eta ~ N(0, I)
903905
theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta
904906
"""
905-
ratio = keras.ops.convert_to_tensor(
906-
(stop_time + eps) / start_time, dtype=keras.ops.dtype(next(iter(state.values())))
907-
)
907+
log_snr_t = noise_schedule.get_log_snr(t=start_time, training=False)
908+
_, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
908909

909-
T = steps
910910
# main loops
911-
for t_T in range(T - 1, 0, -1):
912-
t = t_T / T
913-
dt = keras.ops.convert_to_tensor(stop_time, dtype=keras.ops.dtype(next(iter(state.values())))) * (
914-
ratio ** (stop_time - t)
915-
)
916-
917-
sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
918-
# inner L Langevin steps at level t
919-
for _ in range(L):
920-
# score
911+
for step in range(steps - 1, 0, -1):
912+
t = step / steps
913+
log_snr_t = noise_schedule.get_log_snr(t=t, training=False)
914+
_, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
915+
annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t)
916+
917+
sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size))
918+
for _ in range(langevin_corrector_steps):
921919
drift = score_fn(t, **filter_kwargs(state, score_fn))
922-
# noise
923-
eta = {
920+
noise = {
924921
k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed)
925922
for k, v in state.items()
926923
}
927924

928925
# update
929926
for k, d in drift.items():
930-
state[k] = state[k] + 0.5 * dt * d + sqrt_dt * eta[k]
927+
state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k]
931928
return state

0 commit comments

Comments
 (0)