Skip to content

Commit 64d4373

Browse files
committed
add predictor corrector sampling
1 parent 0a87694 commit 64d4373

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,9 +836,26 @@ def deltas(time, xz):
836836
def diffusion(time, xz):
837837
return {"xz": self.diffusion_term(xz, time=time, training=training)}
838838

839+
scores = None
840+
if "corrector_steps" in integrate_kwargs:
841+
if integrate_kwargs["corrector_steps"] > 0:
842+
843+
def scores(time, xz):
844+
return {
845+
"xz": self.compositional_score(
846+
xz,
847+
time=time,
848+
conditions=conditions,
849+
compute_prior_score=compute_prior_score,
850+
mini_batch_size=mini_batch_size,
851+
training=training,
852+
)
853+
}
854+
839855
state = integrate_stochastic(
840856
drift_fn=deltas,
841857
diffusion_fn=diffusion,
858+
score_fn=scores,
842859
state=state,
843860
seed=self.seed_generator,
844861
**integrate_kwargs,

bayesflow/utils/integrate.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,17 @@ def integrate_stochastic(
401401
steps: int,
402402
seed: keras.random.SeedGenerator,
403403
method: str = "euler_maruyama",
404+
score_fn: Callable = None,
405+
corrector_steps: int = 0,
404406
**kwargs,
405407
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
406408
"""
407409
Integrates a stochastic differential equation from start_time to stop_time.
408410
411+
When score_fn is provided, performs predictor-corrector sampling where:
412+
- Predictor: reverse diffusion SDE solver
413+
- Corrector: annealed Langevin dynamics with step size e = sqrt(dim)
414+
409415
Args:
410416
drift_fn: Function that computes the drift term.
411417
diffusion_fn: Function that computes the diffusion term.
@@ -415,11 +421,13 @@ def integrate_stochastic(
415421
steps: Number of integration steps.
416422
seed: Random seed for noise generation.
417423
method: Integration method to use, e.g., 'euler_maruyama'.
424+
score_fn: Optional score function for predictor-corrector sampling.
425+
Should take (time, **state) and return score dict.
426+
corrector_steps: Number of corrector steps to take after each predictor step.
418427
**kwargs: Additional arguments to pass to the step function.
419428
420429
Returns:
421-
If return_noise is False, returns the final state dictionary.
422-
If return_noise is True, returns a tuple of (final_state, noise_history).
430+
Final state dictionary after integration.
423431
"""
424432
if steps <= 0:
425433
raise ValueError("Number of steps must be positive.")
@@ -438,17 +446,44 @@ def integrate_stochastic(
438446
step_size = (stop_time - start_time) / steps
439447
sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size))
440448

441-
# Pre-generate noise history: shape = (steps, *state_shape)
449+
# Pre-generate noise history for predictor: shape = (steps, *state_shape)
442450
noise_history = {}
443451
for key, val in state.items():
444452
noise_history[key] = (
445453
keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt
446454
)
447455

456+
# Pre-generate corrector noise if score_fn is provided: shape = (steps, corrector_steps, *state_shape)
457+
corrector_noise_history = {}
458+
if score_fn is not None and corrector_steps > 0:
459+
for key, val in state.items():
460+
corrector_noise_history[key] = keras.random.normal(
461+
(steps, corrector_steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed
462+
)
463+
448464
def body(_loop_var, _loop_state):
449465
_current_state, _current_time = _loop_state
450466
_noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()}
467+
468+
# Predictor step
451469
new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
470+
471+
# Corrector steps: annealed Langevin dynamics if score_fn is provided
472+
if score_fn is not None:
473+
first_key = next(iter(new_state.keys()))
474+
dim = keras.ops.cast(keras.ops.shape(new_state[first_key])[-1], keras.ops.dtype(new_state[first_key]))
475+
e = keras.ops.sqrt(dim)
476+
sqrt_2e = keras.ops.sqrt(2.0 * e)
477+
478+
for corrector_step in range(corrector_steps):
479+
score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
480+
_corrector_noise = {k: corrector_noise_history[k][_loop_var, corrector_step] for k in new_state.keys()}
481+
482+
# Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
483+
for k in new_state.keys():
484+
if k in score:
485+
new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k]
486+
452487
return new_state, new_time
453488

454489
final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time))

0 commit comments

Comments
 (0)