Skip to content

Commit 548f51b

Browse files
committed
stochastic sampler fix
1 parent 2fd5a90 commit 548f51b

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -528,9 +528,9 @@ def velocity(
528528
self,
529529
xz: Tensor,
530530
time: float | Tensor,
531+
stochastic_solver: bool,
531532
conditions: Tensor = None,
532533
training: bool = False,
533-
stochastic_solver: bool = False,
534534
clip_x: bool = False,
535535
) -> Tensor:
536536
# calculate the current noise level and transform into correct shape
@@ -583,7 +583,7 @@ def _velocity_trace(
583583
training: bool = False,
584584
) -> (Tensor, Tensor):
585585
def f(x):
586-
return self.velocity(x, time=time, conditions=conditions, training=training)
586+
return self.velocity(x, time=time, stochastic_solver=False, conditions=conditions, training=training)
587587

588588
v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
589589

@@ -630,7 +630,9 @@ def deltas(time, xz):
630630
return z, log_density
631631

632632
def deltas(time, xz):
633-
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
633+
return {
634+
"xz": self.velocity(xz, time=time, stochastic_solver=False, conditions=conditions, training=training)
635+
}
634636

635637
state = {"xz": x}
636638
state = integrate(
@@ -676,12 +678,14 @@ def deltas(time, xz):
676678

677679
return x, log_density
678680

679-
def deltas(time, xz):
680-
return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
681-
682681
state = {"xz": z}
683682
if integrate_kwargs["method"] == "euler_maruyama":
684683

684+
def deltas(time, xz):
685+
return {
686+
"xz": self.velocity(xz, time=time, stochastic_solver=True, conditions=conditions, training=training)
687+
}
688+
685689
def diffusion(time, xz):
686690
return {"xz": self.compute_diffusion_term(xz, time=time, training=training)}
687691

@@ -692,6 +696,14 @@ def diffusion(time, xz):
692696
**integrate_kwargs,
693697
)
694698
else:
699+
700+
def deltas(time, xz):
701+
return {
702+
"xz": self.velocity(
703+
xz, time=time, stochastic_solver=False, conditions=conditions, training=training
704+
)
705+
}
706+
695707
state = integrate(
696708
deltas,
697709
state,
@@ -709,6 +721,7 @@ def compute_metrics(
709721
stage: str = "training",
710722
) -> dict[str, Tensor]:
711723
training = stage == "training"
724+
noise_schedule_training_stage = stage == "training" or stage == "validation"
712725
if not self.built:
713726
xz_shape = keras.ops.shape(x)
714727
conditions_shape = None if conditions is None else keras.ops.shape(conditions)
@@ -723,8 +736,10 @@ def compute_metrics(
723736
# t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
724737

725738
# calculate the noise level
726-
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=training), x)
727-
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
739+
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage), x)
740+
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(
741+
log_snr_t=log_snr_t, training=noise_schedule_training_stage
742+
)
728743

729744
# generate noise vector
730745
eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)

0 commit comments

Comments
 (0)