Skip to content

Commit 9ed482d

Browse files
committed
stochastic sampler
1 parent de532c7 commit 9ed482d

File tree

1 file changed

+34
-41
lines changed

1 file changed

+34
-41
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ class DiffusionModel(InferenceNetwork):
374374
}
375375

376376
INTEGRATE_DEFAULT_CONFIG = {
377-
"method": "euler",
377+
"method": "euler", # or euler_maruyama
378378
"steps": 100,
379379
}
380380

@@ -530,6 +530,7 @@ def velocity(
530530
time: float | Tensor,
531531
conditions: Tensor = None,
532532
training: bool = False,
533+
stochastic_solver: bool = False,
533534
clip_x: bool = False,
534535
) -> Tensor:
535536
# calculate the current noise level and transform into correct shape
@@ -549,44 +550,28 @@ def velocity(
549550
# convert x to score
550551
score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
551552

552-
# compute velocity for the ODE depending on the noise schedule
553+
# compute velocity f, g of the SDE or ODE
553554
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
554-
# out = f - 0.5 * g_squared * score
555-
out = f - g_squared * score
556555

557-
# todo: for the SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
556+
if stochastic_solver:
557+
# for the SDE: d(z) = [f(z, t) - g(t) ^ 2 * score(z, lambda )] dt + g(t) dW
558+
out = f - g_squared * score
559+
else:
560+
# for the ODE: d(z) = [f(z, t) - 0.5 * g(t) ^ 2 * score(z, lambda )] dt
561+
out = f - 0.5 * g_squared * score
562+
558563
return out
559564

560-
def velocity2(
565+
def compute_diffusion_term(
561566
self,
562567
xz: Tensor,
563568
time: float | Tensor,
564-
conditions: Tensor = None,
565569
training: bool = False,
566-
clip_x: bool = False,
567570
) -> Tensor:
568571
# calculate the current noise level and transform into correct shape
569572
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
570573
log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
571-
# alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
572-
573-
# if conditions is None:
574-
# xtc = keras.ops.concatenate([xz, log_snr_t], axis=-1)
575-
# else:
576-
# xtc = keras.ops.concatenate([xz, log_snr_t, conditions], axis=-1)
577-
# pred = self.output_projector(self.subnet(xtc, training=training), training=training)
578-
579-
# x_pred = self.convert_prediction_to_x(
580-
# pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=clip_x
581-
# )
582-
# convert x to score
583-
# score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
584-
585-
# compute velocity for the ODE depending on the noise schedule
586-
f, g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
587-
# out = f - 0.5 * g_squared * score
588-
# out = f - g_squared * score
589-
574+
g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t)
590575
return ops.sqrt(g_squared)
591576

592577
def _velocity_trace(
@@ -620,6 +605,9 @@ def _forward(
620605
| self.integrate_kwargs
621606
| kwargs
622607
)
608+
if integrate_kwargs["method"] == "euler_maruyama":
609+
raise ValueError("Stoachastic methods are not supported for forward integration.")
610+
623611
if density:
624612

625613
def deltas(time, xz):
@@ -670,6 +658,8 @@ def _inverse(
670658
| kwargs
671659
)
672660
if density:
661+
if integrate_kwargs["method"] == "euler_maruyama":
662+
raise ValueError("Stoachastic methods are not supported for density computation.")
673663

674664
def deltas(time, xz):
675665
v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
@@ -689,21 +679,24 @@ def deltas(time, xz):
689679
def deltas(time, xz):
690680
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
691681

692-
def diffusion(time, xz):
693-
return {"xz": self.velocity2(xz, time=time, conditions=conditions, training=training)}
694-
695682
state = {"xz": z}
696-
# state = integrate(
697-
# deltas,
698-
# state,
699-
# **integrate_kwargs,
700-
# )
701-
state = integrate_stochastic(
702-
deltas,
703-
diffusion,
704-
state,
705-
**integrate_kwargs,
706-
)
683+
if integrate_kwargs["method"] == "euler_maruyama":
684+
685+
def diffusion(time, xz):
686+
return {"xz": self.compute_diffusion_term(xz, time=time, training=training)}
687+
688+
state = integrate_stochastic(
689+
deltas,
690+
diffusion,
691+
state,
692+
**integrate_kwargs,
693+
)
694+
else:
695+
state = integrate(
696+
deltas,
697+
state,
698+
**integrate_kwargs,
699+
)
707700

708701
x = state["xz"]
709702
return x

0 commit comments

Comments
 (0)