Skip to content

Commit e0b3bd5

Browse files
committed
add predictor corrector sampling
1 parent 9402941 commit e0b3bd5

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

bayesflow/utils/integrate.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def integrate_stochastic(
404404
score_fn: Callable = None,
405405
corrector_steps: int = 0,
406406
noise_schedule=None,
407+
r: float = 0.1,
407408
**kwargs,
408409
) -> Union[dict[str, ArrayLike], tuple[dict[str, ArrayLike], dict[str, Sequence[ArrayLike]]]]:
409410
"""
@@ -426,6 +427,7 @@ def integrate_stochastic(
426427
Should take (time, **state) and return score dict.
427428
corrector_steps: Number of corrector steps to take after each predictor step.
428429
noise_schedule: Noise schedule object for computing lambda_t and alpha_t in corrector.
430+
r: Scaling factor for corrector step size.
429431
**kwargs: Additional arguments to pass to the step function.
430432
431433
Returns:
@@ -482,10 +484,9 @@ def body(_loop_var, _loop_state):
482484
# Compute noise schedule components for corrector step size
483485
log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
484486
alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
485-
lambda_t = keras.ops.exp(-log_snr_t) # lambda_t from noise schedule
486487

487488
# Corrector update: x_i+1 = x_i + e * score + sqrt(2e) * noise_corrector
488-
# where e = 2*alpha_t * (lambda_t * ||z|| / ||score||)**2
489+
# where e = 2*alpha_t * (r * ||z|| / ||score||)**2
489490
for k in new_state.keys():
490491
if k in score:
491492
z_norm = keras.ops.norm(new_state[k], axis=-1, keepdims=True)
@@ -494,7 +495,7 @@ def body(_loop_var, _loop_state):
494495
# Prevent division by zero
495496
score_norm = keras.ops.maximum(score_norm, 1e-8)
496497

497-
e = 2.0 * alpha_t * (lambda_t * z_norm / score_norm) ** 2
498+
e = 2.0 * alpha_t * (r * z_norm / score_norm) ** 2
498499
sqrt_2e = keras.ops.sqrt(2.0 * e)
499500

500501
new_state[k] = new_state[k] + e * score[k] + sqrt_2e * _corrector_noise[k]

0 commit comments

Comments
 (0)