Commit e380f5e

cleanup: remove linear schedule, minor fixes

- set default sigma_data when F-prediction is used with schedules other than EDM
- modify clip_x behavior
- minor changes to docstrings/comments

1 parent: f235671

1 file changed (+22, -75)

bayesflow/experimental/diffusion_model.py

@@ -147,66 +147,12 @@ def validate(self):
             raise ValueError("dt/t log_snr(1) must be finite.")
 
 
-@serializable
-class LinearNoiseSchedule(NoiseSchedule):
-    """Linear noise schedule for diffusion models.
-
-    The linear noise schedule with likelihood weighting is based on [1].
-
-    [1] Maximum Likelihood Training of Score-Based Diffusion Models: Song et al. (2021)
-    """
-
-    def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
-        super().__init__(name="linear_noise_schedule", variance_type="preserving", weighting="likelihood_weighting")
-        self.log_snr_min = min_log_snr
-        self.log_snr_max = max_log_snr
-
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
-
-    def _truncated_t(self, t: Tensor) -> Tensor:
-        return self._t_min + (self._t_max - self._t_min) * t
-
-    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
-        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
-        t_trunc = self._truncated_t(t)
-        # SNR = -log(exp(t^2) - 1)
-        # equivalent, but more stable: -t^2 - log(1 - exp(-t^2))
-        return -ops.square(t_trunc) - ops.log(1 - ops.exp(-ops.square(t_trunc)))
-
-    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
-        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
-        # SNR = -log(exp(t^2) - 1) => t = sqrt(log(1 + exp(-snr)))
-        return ops.sqrt(ops.softplus(-log_snr_t))
-
-    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
-        """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
-        t = self.get_t_from_log_snr(log_snr_t=log_snr_t, training=training)
-
-        # Compute the truncated time t_trunc
-        t_trunc = self._truncated_t(t)
-        dsnr_dx = -2 * t_trunc / (1 - ops.exp(-(t_trunc**2)))
-
-        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
-        # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
-        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
-        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
-        return -factor * dsnr_dt
-
-    def get_config(self):
-        return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max)
-
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
-
-
 @serializable
 class CosineNoiseSchedule(NoiseSchedule):
     """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
     For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
 
-    [1] Diffusion models beat gans on image synthesis: Dhariwal and Nichol (2022)
+    [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022)
     """
 
     def __init__(
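
Side note on the removed schedule's math: the stable form in get_log_snr follows from factoring exp(t^2) out of exp(t^2) - 1, so -log(exp(t^2) - 1) = -t^2 - log(1 - exp(-t^2)), and the inverse map is t = sqrt(softplus(-log_snr)). A standalone NumPy check of both identities (illustrative, not part of the codebase):

import numpy as np

def log_snr_naive(t):
    # -log(exp(t^2) - 1): exp(t^2) overflows float64 once t^2 exceeds ~709
    return -np.log(np.exp(t**2) - 1.0)

def log_snr_stable(t):
    # algebraically identical form used by the removed schedule
    return -(t**2) - np.log(1.0 - np.exp(-(t**2)))

t = np.array([0.5, 2.0, 30.0])
with np.errstate(over="ignore"):
    print(log_snr_naive(t))   # [ 1.2588 -3.9815    -inf]  <- overflow at t = 30
print(log_snr_stable(t))      # [ 1.2588 -3.9815 -900.  ]  <- finite everywhere

# inverse used by get_t_from_log_snr: softplus(x) = logaddexp(0, x)
lam = log_snr_stable(t)
print(np.allclose(np.sqrt(np.logaddexp(0.0, -lam)), t))  # True
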
@@ -371,6 +317,7 @@ class DiffusionModel(InferenceNetwork):
 
     def __init__(
         self,
+        *,
         subnet: str | type = "mlp",
         integrate_kwargs: dict[str, any] = None,
         subnet_kwargs: dict[str, any] = None,
@@ -384,8 +331,8 @@ def __init__(
         This model learns a transformation from a Gaussian latent distribution to a target distribution using a
         specified subnet type, which can be an MLP or a custom network.
 
-        The integration steps can be customized with additional parameters available in the respective
-        configuration dictionary.
+        The integration can be customized with additional parameters available in the integrate_kwargs
+        configuration dictionary. Different noise schedules and prediction types are available.
 
         Parameters
         ----------
@@ -397,7 +344,7 @@ def __init__(
         subnet_kwargs : dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
         noise_schedule : str or NoiseSchedule, optional
-            The noise schedule used for the diffusion process. Can be "linear", "cosine", or "edm".
+            The noise schedule used for the diffusion process. Can be "cosine" or "edm".
             Default is "edm".
         prediction_type: str, optional
             The type of prediction used in the diffusion model. Can be "velocity", "noise" or "F" (EDM).
@@ -408,9 +355,7 @@ def __init__(
         super().__init__(base_distribution="normal", **kwargs)
 
         if isinstance(noise_schedule, str):
-            if noise_schedule == "linear":
-                noise_schedule = LinearNoiseSchedule()
-            elif noise_schedule == "cosine":
+            if noise_schedule == "cosine":
                 noise_schedule = CosineNoiseSchedule()
             elif noise_schedule == "edm":
                 noise_schedule = EDMNoiseSchedule()
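
With "linear" gone, the string dispatch accepts only "cosine" and "edm"; passing a NoiseSchedule instance still bypasses it. A hypothetical usage sketch (both classes are defined in the file shown above; the exact constructor defaults are assumptions):

from bayesflow.experimental.diffusion_model import DiffusionModel, CosineNoiseSchedule

model = DiffusionModel(noise_schedule="edm")     # default, dispatches to EDMNoiseSchedule()
model = DiffusionModel(noise_schedule="cosine")  # dispatches to CosineNoiseSchedule()
model = DiffusionModel(noise_schedule=CosineNoiseSchedule())  # instances skip the dispatch
# noise_schedule="linear" no longer matches any branch after this commit
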
@@ -435,10 +380,12 @@ def __init__(
         )
 
         # clipping of prediction (after it was transformed to x-prediction)
-        self._clip_min = -5.0
-        self._clip_max = 5.0
+        # keeping this private for now, as it is usually not required in SBI and somewhat dangerous
+        self._clip_x = kwargs.get("clip_x", None)
+        if self._clip_x is not None:
+            if len(self._clip_x) != 2 or self._clip_x[0] > self._clip_x[1]:
+                raise ValueError("'clip_x' has to be a list or tuple with the values [x_min, x_max]")
 
-        # latent distribution (not configurable)
         self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
         self.seed_generator = keras.random.SeedGenerator()
 
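The hard-coded [-5, 5] window becomes an opt-in clip_x kwarg, validated at construction. Since clipping is applied after the network output is converted to x-space, the bounds live in data units, which is exactly why a silent default is risky for SBI targets of arbitrary scale. A hedged sketch of the new contract (only the kwarg name and validation come from the diff; the calls are illustrative):

model = DiffusionModel()                    # _clip_x is None: no clipping at all
model = DiffusionModel(clip_x=[-5.0, 5.0])  # restores the old behavior explicitly
model = DiffusionModel(clip_x=(0.0, 1.0))   # any two-element list/tuple with x_min <= x_max
DiffusionModel(clip_x=[5.0, -5.0])          # ValueError: x_min > x_max
DiffusionModel(clip_x=[1.0])                # ValueError: needs exactly two values
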
@@ -456,6 +403,8 @@ def __init__(
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
 
+        self._kwargs = kwargs
+
     def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
         if self.built:
             return
@@ -480,7 +429,7 @@ def get_config(self):
         base_config = super().get_config()
         base_config = layer_kwargs(base_config)
 
-        config = {
+        config = self._kwargs | {
             "subnet": self.subnet,
             "noise_schedule": self.noise_schedule,
             "integrate_kwargs": self.integrate_kwargs,
@@ -494,7 +443,7 @@ def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))
 
     def convert_prediction_to_x(
-        self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
+        self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor
     ) -> Tensor:
         """Convert the prediction of the neural network to the x space."""
         if self._prediction_type == "velocity":
@@ -504,7 +453,7 @@ def convert_prediction_to_x(
             # convert noise prediction into x
             x = (z - sigma_t * pred) / alpha_t
         elif self._prediction_type == "F":  # EDM
-            sigma_data = self.noise_schedule.sigma_data
+            sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
             x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
             x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
             x = x1 * z + x2 * pred
@@ -513,8 +462,8 @@ def convert_prediction_to_x(
         else:  # "score"
             x = (z + sigma_t**2 * pred) / alpha_t
 
-        if clip_x:
-            x = ops.clip(x, self._clip_min, self._clip_max)
+        if self._clip_x is not None:
+            x = ops.clip(x, self._clip_x[0], self._clip_x[1])
         return x
 
     def velocity(
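
The hasattr fallback makes F-prediction usable with schedules that do not define sigma_data (e.g. the cosine schedule), implicitly assuming unit data variance. The branch itself is EDM preconditioning written in log-SNR form, using exp(-log_snr_t) = sigma_t**2 / alpha_t**2. A standalone NumPy re-expression (illustrative, not the library code; c_skip/c_out follow the EDM paper's naming for the diff's x1/x2):

import numpy as np

def f_to_x(pred, z, alpha_t, log_snr_t, sigma_data=1.0):
    # sigma_data=1.0 mirrors the new fallback for schedules without the attribute
    var = np.exp(-log_snr_t)  # equals sigma_t**2 / alpha_t**2
    c_skip = sigma_data**2 * alpha_t / (var + sigma_data**2)           # x1 in the diff
    c_out = np.sqrt(var) * sigma_data / np.sqrt(var + sigma_data**2)   # x2 in the diff
    return c_skip * z + c_out * pred
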
@@ -524,7 +473,6 @@ def velocity(
         stochastic_solver: bool,
         conditions: Tensor = None,
         training: bool = False,
-        clip_x: bool = False,
     ) -> Tensor:
         # calculate the current noise level and transform into correct shape
         log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
@@ -537,9 +485,7 @@ def velocity(
         xtc = ops.concatenate([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
-        x_pred = self.convert_prediction_to_x(
-            pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=clip_x
-        )
+        x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
         # convert x to score
         score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
 
@@ -725,6 +671,7 @@ def compute_metrics(
         stage: str = "training",
     ) -> dict[str, Tensor]:
         training = stage == "training"
+        # use same noise schedule for training and validation to keep them comparable
         noise_schedule_training_stage = stage == "training" or stage == "validation"
         if not self.built:
             xz_shape = ops.shape(x)
@@ -760,7 +707,7 @@ def compute_metrics(
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
-            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=False
+            pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
         )
 
         # Calculate loss
@@ -775,7 +722,7 @@ def compute_metrics(
             loss = weights_for_snr * ops.mean((velocity_pred - v_t) ** 2, axis=-1)
         elif self._loss_type == "F":
             # convert x to F prediction
-            sigma_data = self.noise_schedule.sigma_data
+            sigma_data = self.noise_schedule.sigma_data if hasattr(self.noise_schedule, "sigma_data") else 1.0
             x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)
             x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2))
             f_pred = x1 * x_pred - x2 * diffused_x
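
Consistency check (our derivation, not stated in the commit): the x1/x2 in this loss branch are exactly the inverse of the x1/x2 in convert_prediction_to_x. Writing \(v = e^{-\lambda}\) for the inverse SNR and \(\sigma_d\) for sigma_data, the forward map is

\[
x = \frac{\sigma_d^2\,\alpha_t}{v + \sigma_d^2}\, z \;+\; \frac{\sqrt{v}\,\sigma_d}{\sqrt{v + \sigma_d^2}}\, F,
\]

and solving for \(F\) gives

\[
F = \frac{\sqrt{v + \sigma_d^2}}{\sqrt{v}\,\sigma_d}\, x \;-\; \frac{\sigma_d\,\alpha_t}{\sqrt{v}\,\sqrt{v + \sigma_d^2}}\, z,
\]

whose coefficients are exactly the x1 and x2 above (note \(e^{-\lambda/2} = \sqrt{v}\)), so f_pred = x1 * x_pred - x2 * diffused_x recovers the network's F-target.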
