
Commit eb96620

fix base distribution

1 parent 5b52499

1 file changed: bayesflow/experimental/diffusion_model.py (+24 lines, -37 lines)
@@ -45,18 +45,6 @@ def __init__(self, name: str, variance_type: str):
         self.variance_type = variance_type  # 'exploding' or 'preserving'
         self._log_snr_min = -15  # should be set in the subclasses
         self._log_snr_max = 15  # should be set in the subclasses
-        self.sigma_data = 1.0
-
-    @property
-    def scale_base_distribution(self):
-        """Get the scale of the base distribution."""
-        if self.variance_type == "preserving":
-            return 1.0
-        elif self.variance_type == "exploding":
-            # e.g., EDM is a variance exploding schedule
-            return ops.sqrt(ops.exp(-self._log_snr_min))
-        else:
-            raise ValueError(f"Unknown variance type: {self.variance_type}")
 
     @abstractmethod
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
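
Note on the removed property: for a variance-exploding schedule, sqrt(exp(-log_snr_min)) is just the maximum noise level, since alpha_t = 1 there and so log_snr = -2 * log(sigma). A minimal sketch of that identity, using the EDM paper's sigma_max = 80.0 as a hypothetical value:

import math

sigma_max = 80.0                        # hypothetical, EDM paper default
log_snr_min = -2 * math.log(sigma_max)  # lowest log-SNR under this convention

# The deleted scale_base_distribution computed sqrt(exp(-log_snr_min)),
# which recovers sigma_max exactly:
assert math.isclose(math.sqrt(math.exp(-log_snr_min)), sigma_max)
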
@@ -106,8 +94,8 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         """
         if self.variance_type == "preserving":
             # variance preserving schedule
-            alpha_t = keras.ops.sqrt(keras.ops.sigmoid(log_snr_t))
-            sigma_t = keras.ops.sqrt(keras.ops.sigmoid(-log_snr_t))
+            alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
+            sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
         elif self.variance_type == "exploding":
             # variance exploding schedule
             alpha_t = ops.ones_like(log_snr_t)
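
The variance-preserving branch keeps alpha_t**2 + sigma_t**2 = 1 at every noise level, because sigmoid(x) + sigmoid(-x) = 1. A quick numerical check of that invariant (NumPy used here purely for illustration):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

log_snr = np.linspace(-15, 15, 7)   # spans the schedule's default SNR range
alpha = np.sqrt(sigmoid(log_snr))
sigma = np.sqrt(sigmoid(-log_snr))
assert np.allclose(alpha**2 + sigma**2, 1.0)
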
@@ -271,6 +259,7 @@ def from_config(cls, config, custom_objects=None):
 class EDMNoiseSchedule(NoiseSchedule):
     """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
     This should be used with the F-prediction type in the diffusion model.
+    Since the schedule is variance exploding, the base distribution is a Gaussian with scale 'sigma_max'.
 
     [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
     """
@@ -301,7 +290,7 @@ def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
             loc = -2 * self.p_mean
             scale = 2 * self.p_std
             snr = -(loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2))
-            snr = keras.ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
+            snr = ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
         else:  # sampling
             sigma_min_rho = self.sigma_min ** (1 / self.rho)
             sigma_max_rho = self.sigma_max ** (1 / self.rho)
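
In the training branch above, sqrt(2) * erfinv(2 * t - 1) is the standard normal quantile function, so mapping a uniform t through it yields a Gaussian log-SNR, which is EDM's log-normal noise draw in disguise. A minimal sketch, with the EDM paper's defaults as hypothetical parameters and scipy supplying erfinv:

import numpy as np
from scipy.special import erfinv

p_mean, p_std = -1.2, 1.2            # hypothetical, EDM paper defaults
loc, scale = -2 * p_mean, 2 * p_std

t = np.random.uniform(size=200_000)  # uniform draws, as in training
log_snr = -(loc + scale * erfinv(2 * t - 1) * np.sqrt(2))

# The transform makes log_snr Gaussian with mean -loc and std scale:
print(log_snr.mean(), log_snr.std())  # approximately -2.4 and 2.4 here
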
@@ -375,7 +364,7 @@ class DiffusionModel(InferenceNetwork):
 
     INTEGRATE_DEFAULT_CONFIG = {
         "method": "euler",  # or euler_maruyama
-        "steps": 100,
+        "steps": 250,
     }
 
     def __init__(
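
The default step count rises from 100 to 250; callers can still override it, since the constructor merges user kwargs over the defaults with the dict union operator. A minimal sketch of that merge (Python 3.9+), with the override value purely hypothetical:

INTEGRATE_DEFAULT_CONFIG = {"method": "euler", "steps": 250}
integrate_kwargs = {"steps": 500}  # hypothetical user override

merged = INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
assert merged == {"method": "euler", "steps": 500}
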
@@ -444,9 +433,7 @@ def __init__(
         self._clip_max = 5.0
 
         # latent distribution (not configurable)
-        self.base_distribution = bf.distributions.DiagonalNormal(
-            mean=0.0, std=self.noise_schedule.scale_base_distribution
-        )
+        self.base_distribution = bf.distributions.DiagonalNormal()
         self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
         self.seed_generator = keras.random.SeedGenerator()

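This is the core of the fix: the latent distribution is now a unit normal rather than one whose std was pulled from the schedule. The two parameterizations describe the same family, since N(0, sigma**2) is sigma times a standard normal; a minimal sketch of that equivalence, with sigma_max as a hypothetical variance-exploding scale (where any such scaling now lives is outside this diff):

import numpy as np

sigma_max = 80.0  # hypothetical scale, for illustration
rng = np.random.default_rng(0)

z_old = rng.normal(loc=0.0, scale=sigma_max, size=100_000)  # old: scale baked in
z_new = sigma_max * rng.normal(size=100_000)                # new: unit normal, scaled outside

# Same distribution either way; only where the scaling lives changed.
assert abs(z_old.std() - z_new.std()) < 1.0
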
@@ -521,7 +508,7 @@ def convert_prediction_to_x(
         x = (z + sigma_t**2 * pred) / alpha_t
 
         if clip_x:
-            x = keras.ops.clip(x, self._clip_min, self._clip_max)
+            x = ops.clip(x, self._clip_min, self._clip_max)
         return x
 
     def velocity(
@@ -535,13 +522,13 @@ def velocity(
     ) -> Tensor:
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
-        log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
+        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
 
         if conditions is None:
-            xtc = keras.ops.concatenate([xz, log_snr_t], axis=-1)
+            xtc = ops.concatenate([xz, log_snr_t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([xz, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([xz, log_snr_t, conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
@@ -570,7 +557,7 @@ def compute_diffusion_term(
     ) -> Tensor:
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
-        log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
+        log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         g_squared = self.noise_schedule.get_drift_diffusion(log_snr_t=log_snr_t)
         return ops.sqrt(g_squared)

@@ -587,7 +574,7 @@ def f(x):
 
         v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
 
-        return v, keras.ops.expand_dims(trace, axis=-1)
+        return v, ops.expand_dims(trace, axis=-1)
 
     def _forward(
         self,
@@ -616,7 +603,7 @@ def deltas(time, xz):
 
         state = {
             "xz": x,
-            "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x)),
+            "trace": ops.zeros(ops.shape(x)[:-1] + (1,), dtype=ops.dtype(x)),
         }
         state = integrate(
             deltas,
@@ -625,7 +612,7 @@ def deltas(time, xz):
         )
 
         z = state["xz"]
-        log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)
+        log_density = self.base_distribution.log_prob(z) + ops.squeeze(state["trace"], axis=-1)
 
         return z, log_density

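The log-density line follows the continuous change-of-variables formula: integrating the trace of the velocity's Jacobian along the ODE path gives the correction term, log p(x) = log p(z) + integral of tr(dv/dx) dt (with the opposite sign in the inverse direction). A toy 1-D check of that bookkeeping, assuming the simple velocity v(x) = a * x, for which the exact map is z = x * exp(a) and the exact trace integral is a:

import numpy as np

a, x0 = 0.7, 1.3      # hypothetical velocity slope and input
steps = 10_000
x, trace = x0, 0.0
dt = 1.0 / steps
for _ in range(steps):  # Euler integration of dx/dt = a*x, dtrace/dt = a
    trace += a * dt
    x += a * x * dt

assert np.isclose(x, x0 * np.exp(a), rtol=1e-3)  # matches the exact flow
assert np.isclose(trace, a)                      # the log-density correction
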
@@ -669,12 +656,12 @@ def deltas(time, xz):
 
         state = {
             "xz": z,
-            "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z)),
+            "trace": ops.zeros(ops.shape(z)[:-1] + (1,), dtype=ops.dtype(z)),
         }
         state = integrate(deltas, state, **integrate_kwargs)
 
         x = state["xz"]
-        log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)
+        log_density = self.base_distribution.log_prob(z) - ops.squeeze(state["trace"], axis=-1)
 
         return x, log_density

@@ -723,17 +710,17 @@ def compute_metrics(
         training = stage == "training"
         noise_schedule_training_stage = stage == "training" or stage == "validation"
         if not self.built:
-            xz_shape = keras.ops.shape(x)
-            conditions_shape = None if conditions is None else keras.ops.shape(conditions)
+            xz_shape = ops.shape(x)
+            conditions_shape = None if conditions is None else ops.shape(conditions)
             self.build(xz_shape, conditions_shape)
 
         # sample training diffusion time as low discrepancy sequence to decrease variance
         # t_i = \mod (u_0 + i/k, 1)
         u0 = keras.random.uniform(shape=(1,), dtype=ops.dtype(x), seed=self.seed_generator)
-        i = ops.arange(0, keras.ops.shape(x)[0], dtype=ops.dtype(x))  # tensor of indices
-        t = (u0 + i / ops.cast(keras.ops.shape(x)[0], dtype=ops.dtype(x))) % 1
-        # i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)
-        # t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
+        i = ops.arange(0, ops.shape(x)[0], dtype=ops.dtype(x))  # tensor of indices
+        t = (u0 + i / ops.cast(ops.shape(x)[0], dtype=ops.dtype(x))) % 1
+        # i = keras.random.randint((ops.shape(x)[0],), minval=0, maxval=self._timesteps)
+        # t = ops.cast(i, ops.dtype(x)) / ops.cast(self._timesteps, ops.dtype(x))
 
         # calculate the noise level
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t, training=noise_schedule_training_stage), x)
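
The low-discrepancy draw above stratifies diffusion time across the batch: one shared uniform offset u0 plus evenly spread indices, so each batch covers [0, 1) on a regular grid instead of clustering the way i.i.d. uniform draws can. A minimal sketch of the sequence t_i = (u_0 + i/k) mod 1, with a hypothetical batch size k:

import numpy as np

k = 8                     # hypothetical batch size
u0 = np.random.uniform()  # one shared offset per batch
i = np.arange(k)
t = (u0 + i / k) % 1.0

# Sorted, the points form a grid with spacing exactly 1/k:
gaps = np.diff(np.sort(t))
assert np.allclose(gaps, 1.0 / k)
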
@@ -749,9 +736,9 @@ def compute_metrics(
 
         # calculate output of the network
         if conditions is None:
-            xtc = keras.ops.concatenate([diffused_x, log_snr_t], axis=-1)
+            xtc = ops.concatenate([diffused_x, log_snr_t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
