Skip to content

Commit 9402941

Browse files
committed
add predictor corrector sampling
1 parent 5b42368 commit 9402941

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ def score_fn(time, xz):
510510
drift_fn=deltas,
511511
diffusion_fn=diffusion,
512512
score_fn=score_fn,
513+
noise_schedule=self.noise_schedule,
513514
state=state,
514515
seed=self.seed_generator,
515516
**integrate_kwargs,
@@ -911,6 +912,7 @@ def score_fn(time, xz):
911912
drift_fn=deltas,
912913
diffusion_fn=diffusion,
913914
score_fn=score_fn,
915+
noise_schedule=self.noise_schedule,
914916
state=state,
915917
seed=self.seed_generator,
916918
**integrate_kwargs,

bayesflow/utils/integrate.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ def integrate_stochastic(
403403
method: str = "euler_maruyama",
404404
score_fn: Callable = None,
405405
corrector_steps: int = 0,
406+
noise_schedule=None,
406407
**kwargs,
407408
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
408409
"""
@@ -424,6 +425,7 @@ def integrate_stochastic(
424425
score_fn: Optional score function for predictor-corrector sampling.
425426
Should take (time, **state) and return score dict.
426427
corrector_steps: Number of corrector steps to take after each predictor step.
428+
noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
427429
**kwargs: Additional arguments to pass to the step function.
428430
429431
Returns:
@@ -455,7 +457,10 @@ def integrate_stochastic(
455457

456458
# Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape)
457459
corrector_noise_history = {}
458-
if score_fn is not None and corrector_steps > 0:
460+
if corrector_steps > 0:
461+
if score_fn is None or noise_schedule is None:
462+
raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.")
463+
459464
for key, val in state.items():
460465
corrector_noise_history[key] = keras.random.normal(
461466
(steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed
@@ -469,19 +474,29 @@ def body(_loop_var, _loop_state):
469474
new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
470475

471476
# Corrector steps: annealed Langevin dynamics if score_fn is provided
472-
if score_fn is not None:
473-
first_key = next(iter(new_state.keys()))
474-
dim = keras.ops.cast(keras.ops.shape(new_state[first_key])[-1], keras.ops.dtype(new_state[first_key]))
475-
e = keras.ops.sqrt(dim)
476-
sqrt_2e = keras.ops.sqrt(2.0 * e)
477-
477+
if corrector_steps > 0:
478478
for corrector_step in range(corrector_steps):
479479
score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
480480
_corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()}
481481

482+
# Compute noise schedule components for corrector step size
483+
log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
484+
alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
485+
lambda_t = keras.ops.exp(-log_snr_t) # lambda_t from noise schedule
486+
482487
# Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
488+
# where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2
483489
for k in new_state.keys():
484490
if k in score:
491+
z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True)
492+
score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
493+
494+
# Prevent division by zero
495+
score_norm = keras.ops.maximum(score_norm, 1e-8)
496+
497+
e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm) ** 2
498+
sqrt_2e = keras.ops.sqrt(2.0 * e)
499+
485500
new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k]
486501

487502
return new_state, new_time

0 commit comments

Comments
 (0)