Skip to content

Commit 5b42368

Browse files
committed
add predictor corrector sampling
1 parent 64d4373 commit 5b42368

File tree

1 file changed

+68
-13
lines changed

1 file changed

+68
-13
lines changed

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,55 @@ def _apply_subnet(
246246
else:
247247
return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)
248248

249+
def score(
    self,
    xz: Tensor,
    time: float | Tensor | None = None,
    log_snr_t: Tensor | None = None,
    conditions: Tensor | None = None,
    training: bool = False,
) -> Tensor:
    """
    Computes the score of the target or latent variable `xz`.

    The score is derived from the network's prediction via
    ``score = (alpha_t * x_pred - xz) / sigma_t**2``, where ``x_pred`` is the
    denoised estimate of `xz` at the given noise level.

    Parameters
    ----------
    xz : Tensor
        The current state of the latent variable `z`, typically of shape (..., D),
        where D is the dimensionality of the latent space.
    time : float or Tensor, optional
        Scalar or tensor representing the time (or noise level) at which the score
        should be computed. Will be broadcast to `xz`. If None, `log_snr_t` must be
        provided. Default is None.
    log_snr_t : Tensor, optional
        The log signal-to-noise ratio at time `t`. Takes precedence over `time` when
        given. If None, `time` must be provided. Default is None.
    conditions : Tensor, optional
        Conditional inputs to the network, such as conditioning variables
        or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
    training : bool, optional
        Whether the model is in training mode. Affects behavior of dropout, batch norm,
        or other stochastic layers. Default is False.

    Returns
    -------
    Tensor
        The score tensor of the same shape as `xz` at the given noise level.

    Raises
    ------
    ValueError
        If both `time` and `log_snr_t` are None.
    """
    if log_snr_t is None:
        if time is None:
            # Fail early with a clear message instead of an opaque error inside
            # the noise schedule.
            raise ValueError("Either `time` or `log_snr_t` must be provided.")
        log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
    # Broadcasting to (..., 1) is a no-op when log_snr_t already has that shape
    # (e.g. when passed in from `velocity`).
    log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
    alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

    subnet_out = self._apply_subnet(
        xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
    )
    pred = self.output_projector(subnet_out, training=training)

    # Convert the network's (parameterization-dependent) prediction to a
    # denoised estimate of x, then form the score.
    x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

    score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
    return score
297+
249298
def velocity(
250299
self,
251300
xz: Tensor,
@@ -282,19 +331,10 @@ def velocity(
282331
The velocity tensor of the same shape as `xz`, representing the right-hand
283332
side of the SDE or ODE at the given `time`.
284333
"""
285-
# calculate the current noise level and transform into correct shape
286334
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
287335
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
288-
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
289-
290-
subnet_out = self._apply_subnet(
291-
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
292-
)
293-
pred = self.output_projector(subnet_out, training=training)
294336

295-
x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
296-
297-
score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
337+
score = self.score(xz, log_snr_t=log_snr_t, conditions=conditions, training=training)
298338

299339
# compute velocity f, g of the SDE or ODE
300340
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz, training=training)
@@ -452,9 +492,24 @@ def deltas(time, xz):
452492
def diffusion(time, xz):
453493
return {"xz": self.diffusion_term(xz, time=time, training=training)}
454494

495+
score_fn = None
496+
if "corrector_steps" in integrate_kwargs:
497+
if integrate_kwargs["corrector_steps"] > 0:
498+
499+
def score_fn(time, xz):
500+
return {
501+
"xz": self.score(
502+
xz,
503+
time=time,
504+
conditions=conditions,
505+
training=training,
506+
)
507+
}
508+
455509
state = integrate_stochastic(
456510
drift_fn=deltas,
457511
diffusion_fn=diffusion,
512+
score_fn=score_fn,
458513
state=state,
459514
seed=self.seed_generator,
460515
**integrate_kwargs,
@@ -836,11 +891,11 @@ def deltas(time, xz):
836891
def diffusion(time, xz):
837892
return {"xz": self.diffusion_term(xz, time=time, training=training)}
838893

839-
scores = None
894+
score_fn = None
840895
if "corrector_steps" in integrate_kwargs:
841896
if integrate_kwargs["corrector_steps"] > 0:
842897

843-
def scores(time, xz):
898+
def score_fn(time, xz):
844899
return {
845900
"xz": self.compositional_score(
846901
xz,
@@ -855,7 +910,7 @@ def scores(time, xz):
855910
state = integrate_stochastic(
856911
drift_fn=deltas,
857912
diffusion_fn=diffusion,
858-
score_fn=scores,
913+
score_fn=score_fn,
859914
state=state,
860915
seed=self.seed_generator,
861916
**integrate_kwargs,

0 commit comments

Comments
 (0)