Skip to content

Commit df23f89

Browse files
committed
add annealed_langevin
1 parent 1ac9bff commit df23f89

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
integrate_stochastic,
1717
logging,
1818
tensor_utils,
19+
filter_kwargs,
1920
)
2021
from bayesflow.utils.serialization import serialize, deserialize, serializable
2122

2223
from .schedules.noise_schedule import NoiseSchedule
2324
from .dispatch import find_noise_schedule
2425

26+
ArrayLike = int | float | Tensor
27+
2528

2629
# disable module check, use potential module after moving from experimental
2730
@serializable("bayesflow.networks", disable_module_check=True)
@@ -840,6 +843,26 @@ def diffusion(time, xz):
840843
seed=self.seed_generator,
841844
**integrate_kwargs,
842845
)
846+
elif integrate_kwargs["method"] == "langevin":
847+
848+
def scores(time, xz):
849+
return {
850+
"xz": self.compositional_score(
851+
xz,
852+
time=time,
853+
conditions=conditions,
854+
compute_prior_score=compute_prior_score,
855+
mini_batch_size=mini_batch_size,
856+
training=training,
857+
)
858+
}
859+
860+
state = annealed_langevin(
861+
score_fn=scores,
862+
state=state,
863+
seed=self.seed_generator,
864+
**filter_kwargs(integrate_kwargs, annealed_langevin),
865+
)
843866
else:
844867

845868
def deltas(time, xz):
@@ -859,3 +882,50 @@ def deltas(time, xz):
859882

860883
x = state["xz"]
861884
return x
885+
886+
887+
def annealed_langevin(
888+
score_fn: Callable,
889+
state: dict[str, ArrayLike],
890+
steps: int,
891+
seed: keras.random.SeedGenerator,
892+
L: int = 5,
893+
start_time: ArrayLike = None,
894+
stop_time: ArrayLike = None,
895+
eps: float = 0.01,
896+
) -> dict[str, ArrayLike]:
897+
"""
898+
Annealed Langevin dynamics for diffusion sampling.
899+
900+
for t = T-1,...,1:
901+
for s = 1,...,L:
902+
eta ~ N(0, I)
903+
theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta
904+
"""
905+
ratio = keras.ops.convert_to_tensor(
906+
(stop_time + eps) / start_time, dtype=keras.ops.dtype(next(iter(state.values())))
907+
)
908+
909+
T = steps
910+
# main loops
911+
for t_T in range(T - 1, 0, -1):
912+
t = t_T / T
913+
dt = keras.ops.convert_to_tensor(stop_time, dtype=keras.ops.dtype(next(iter(state.values())))) * (
914+
ratio ** (stop_time - t)
915+
)
916+
917+
sqrt_dt = keras.ops.sqrt(keras.ops.abs(dt))
918+
# inner L Langevin steps at level t
919+
for _ in range(L):
920+
# score
921+
drift = score_fn(t, **filter_kwargs(state, score_fn))
922+
# noise
923+
eta = {
924+
k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed)
925+
for k, v in state.items()
926+
}
927+
928+
# update
929+
for k, d in drift.items():
930+
state[k] = state[k] + 0.5 * dt * d + sqrt_dt * eta[k]
931+
return state

0 commit comments

Comments
 (0)