
Commit 89361f7

add predictor corrector sampling
1 parent e0b3bd5

2 files changed: +4 -71 lines

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 0 additions & 67 deletions
@@ -16,15 +16,12 @@
     integrate_stochastic,
     logging,
     tensor_utils,
-    filter_kwargs,
 )
 from bayesflow.utils.serialization import serialize, deserialize, serializable
 
 from .schedules.noise_schedule import NoiseSchedule
 from .dispatch import find_noise_schedule
 
-ArrayLike = int | float | Tensor
-
 
 # disable module check, use potential module after moving from experimental
 @serializable("bayesflow.networks", disable_module_check=True)
@@ -917,27 +914,6 @@ def score_fn(time, xz):
                 seed=self.seed_generator,
                 **integrate_kwargs,
             )
-        elif integrate_kwargs["method"] == "langevin":
-
-            def scores(time, xz):
-                return {
-                    "xz": self.compositional_score(
-                        xz,
-                        time=time,
-                        conditions=conditions,
-                        compute_prior_score=compute_prior_score,
-                        mini_batch_size=mini_batch_size,
-                        training=training,
-                    )
-                }
-
-            state = annealed_langevin(
-                score_fn=scores,
-                noise_schedule=self.noise_schedule,
-                state=state,
-                seed=self.seed_generator,
-                **filter_kwargs(integrate_kwargs, annealed_langevin),
-            )
         else:
 
             def deltas(time, xz):
@@ -957,46 +933,3 @@ def deltas(time, xz):
 
         x = state["xz"]
         return x
-
-
-def annealed_langevin(
-    score_fn: Callable,
-    noise_schedule: Callable,
-    state: dict[str, ArrayLike],
-    steps: int,
-    seed: keras.random.SeedGenerator,
-    start_time: ArrayLike = None,
-    stop_time: ArrayLike = None,
-    langevin_corrector_steps: int = 5,
-    step_size_factor: float = 0.1,
-) -> dict[str, ArrayLike]:
-    """
-    Annealed Langevin dynamics for diffusion sampling.
-
-    for t = T-1,...,1:
-        for s = 1,...,L:
-            eta ~ N(0, I)
-            theta <- theta + (dt[t]/2) * psi(theta, t) + sqrt(dt[t]) * eta
-    """
-    log_snr_t = noise_schedule.get_log_snr(t=start_time, training=False)
-    _, max_sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
-
-    # main loops
-    for step in range(steps - 1, 0, -1):
-        t = step / steps
-        log_snr_t = noise_schedule.get_log_snr(t=t, training=False)
-        _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
-        annealing_step_size = step_size_factor * keras.ops.square(sigma_t / max_sigma_t)
-
-        sqrt_dt = keras.ops.sqrt(keras.ops.abs(annealing_step_size))
-        for _ in range(langevin_corrector_steps):
-            drift = score_fn(t, **filter_kwargs(state, score_fn))
-            noise = {
-                k: keras.random.normal(keras.ops.shape(v), dtype=keras.ops.dtype(v), seed=seed)
-                for k, v in state.items()
-            }
-
-            # update
-            for k, d in drift.items():
-                state[k] = state[k] + 0.5 * annealing_step_size * d + sqrt_dt * noise[k]
-    return state
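
The removed `annealed_langevin` helper implemented the scheme sketched in its docstring: at each noise level, run a few Langevin corrector steps theta <- theta + (eps/2) * score(theta, t) + sqrt(eps) * eta, with a step size eps that shrinks as the noise level decreases. Below is a minimal standalone sketch of that loop, assuming a toy standard-normal target and a fixed geometric sigma schedule; the removed helper instead pulled sigma_t from a NoiseSchedule and updated a dict-valued state.

```python
# Minimal sketch of annealed Langevin dynamics, independent of BayesFlow.
# Assumptions (not from the diff): a toy N(0, I) target whose score is -x,
# and a geometric sigma schedule standing in for NoiseSchedule.
import numpy as np


def annealed_langevin_sketch(x, steps=50, corrector_steps=5, step_size_factor=0.1, seed=0):
    rng = np.random.default_rng(seed)
    sigmas = np.geomspace(1.0, 0.01, steps)  # decreasing noise levels
    max_sigma = sigmas[0]

    for sigma in sigmas:
        # step size shrinks with the noise level, mirroring
        # step_size_factor * (sigma_t / max_sigma_t) ** 2 in the removed helper
        eps = step_size_factor * (sigma / max_sigma) ** 2
        for _ in range(corrector_steps):
            eta = rng.standard_normal(x.shape)
            # score of N(0, I) is -x; a real sampler would call the model's score here
            x = x + 0.5 * eps * (-x) + np.sqrt(eps) * eta
    return x


samples = annealed_langevin_sketch(np.random.default_rng(0).standard_normal((1000, 2)))
```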

bayesflow/utils/integrate.py

Lines changed: 4 additions & 4 deletions
@@ -404,7 +404,7 @@ def integrate_stochastic(
     score_fn: Callable = None,
     corrector_steps: int = 0,
     noise_schedule=None,
-    r: float = 0.1,
+    step_size_factor: float = 0.1,
     **kwargs,
 ) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
     """
@@ -427,7 +427,7 @@
             Should take (time, **state) and return score dict.
         corrector_steps: Number of corrector steps to take after each predictor step.
         noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
-        r: Scaling factor for corrector step size.
+        step_size_factor: Scaling factor for corrector step size.
         **kwargs: Additional arguments to pass to the step function.
 
     Returns:
@@ -489,13 +489,13 @@ def body(_loop_var, _loop_state):
             # where e = 2*alpha_t * (r * ||z|| / ||score||)**2
             for k in new_state.keys():
                 if k in score:
-                    z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True)
+                    z_norm = keras.ops.norm(_corrector_noise[k], axis=-1, keepdims=True)
                     score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
 
                     # Prevent division by zero
                     score_norm = keras.ops.maximum(score_norm, 1e-8)
 
-                    e = 2.0 * alpha_t * (r * z_norm / score_norm) ** 2
+                    e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2
                     sqrt_2e = keras.ops.sqrt(2.0 * e)
 
                     new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k]
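
The renamed `step_size_factor` drives the corrector update that integrate_stochastic applies after each predictor step: the step size is set from the noise-to-score norm ratio, e = 2 * alpha_t * (step_size_factor * ||z|| / ||score||)^2, and the state moves by e * score plus sqrt(2e) of fresh noise (the bugfix above measures ||z|| on the corrector noise rather than on the state). A minimal sketch of one such update with Keras 3 ops follows, assuming plain tensors for `state` and `score`; the actual function operates on state dicts inside its integration loop.

```python
# Minimal sketch of one Langevin corrector update, assuming Keras 3 ops.
# `state`, `score`, and `alpha_t` are placeholder tensors/scalars, not the
# internals of integrate_stochastic itself.
import keras


def corrector_update(state, score, alpha_t, step_size_factor=0.1, seed=None):
    noise = keras.random.normal(keras.ops.shape(state), dtype=keras.ops.dtype(state), seed=seed)
    z_norm = keras.ops.norm(noise, axis=-1, keepdims=True)
    score_norm = keras.ops.maximum(keras.ops.norm(score, axis=-1, keepdims=True), 1e-8)
    # e = 2 * alpha_t * (step_size_factor * ||z|| / ||score||) ** 2, as in the hunk above
    e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2
    return state + e * score + keras.ops.sqrt(2.0 * e) * noise
```

Repeating this update `corrector_steps` times after each predictor step gives the predictor-corrector sampling named in the commit message.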
