Skip to content

Commit 6794342

Browse files
committed
add loss types
1 parent 1f15b7d commit 6794342

File tree

1 file changed

+67
-34
lines changed

1 file changed

+67
-34
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ class VarianceType(Enum):
2727
EXPLODING = "exploding"
2828

2929

30+
class PredictionType(str, Enum):
    """Parameterization used for the network output and for the training loss.

    Members also derive from ``str`` so that they compare equal to their plain
    string values (``PredictionType.VELOCITY == "velocity"``). Call sites in
    this file still pass plain strings (the model's ``prediction_type`` default
    is the string ``"velocity"``); with a bare ``Enum`` those strings would
    never match any member in ``==`` / ``in [...]`` checks.

    Note: equality with strings does not extend to hashing, so use list/tuple
    membership rather than sets or dict keys when mixing strings and members.
    """

    VELOCITY = "velocity"
    NOISE = "noise"
    X = "x"
    F = "F"  # EDM
    SCORE = "score"
36+
37+
3038
@serializable
3139
class NoiseSchedule(ABC):
3240
r"""Noise schedule for diffusion models. We follow the notation from [1].
@@ -45,11 +53,12 @@ class NoiseSchedule(ABC):
4553
Augmentation: Kingma et al. (2023)
4654
"""
4755

48-
def __init__(self, name: str, variance_type: VarianceType):
56+
def __init__(self, name: str, variance_type: VarianceType, weighting: str = None):
4957
self.name = name
5058
self.variance_type = variance_type # 'exploding' or 'preserving'
5159
self._log_snr_min = -15 # should be set in the subclasses
5260
self._log_snr_max = 15 # should be set in the subclasses
61+
self.weighting = weighting
5362

5463
@abstractmethod
5564
def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
@@ -113,10 +122,18 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
113122
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). Default is 1.
114123
Generally, weighting functions should be defined for a noise prediction loss.
115124
"""
116-
# sigmoid: ops.sigmoid(-log_snr_t + 2), based on Kingma et al. (2023)
117-
# min-snr with gamma = 5, based on Hang et al. (2023)
118-
# 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
119-
return ops.ones_like(log_snr_t)
125+
if self.weighting is None:
126+
return ops.ones_like(log_snr_t)
127+
elif self.weighting == "sigmoid":
128+
# sigmoid weighting based on Kingma et al. (2023)
129+
return ops.sigmoid(-log_snr_t + 2)
130+
elif self.weighting == "likelihood_weighting":
131+
# likelihood weighting based on Song et al. (2021)
132+
g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
133+
sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
134+
return g_squared / ops.square(sigma_t)
135+
else:
136+
raise ValueError(f"Unknown weighting type: {self.weighting}")
120137

121138
def get_config(self):
122139
return dict(name=self.name, variance_type=self.variance_type)
@@ -154,7 +171,9 @@ class LinearNoiseSchedule(NoiseSchedule):
154171
"""
155172

156173
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
157-
super().__init__(name="linear_noise_schedule", variance_type=VarianceType.PRESERVING)
174+
super().__init__(
175+
name="linear_noise_schedule", variance_type=VarianceType.PRESERVING, weighting="likelihood_weighting"
176+
)
158177
self._log_snr_min = min_log_snr
159178
self._log_snr_max = max_log_snr
160179

@@ -190,14 +209,6 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
190209
factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
191210
return -factor * dsnr_dt
192211

193-
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
194-
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
195-
Default is the likelihood weighting based on Song et al. (2021).
196-
"""
197-
g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
198-
sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
199-
return g_squared / ops.square(sigma_t)
200-
201212
def get_config(self):
202213
return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max)
203214

@@ -214,8 +225,10 @@ class CosineNoiseSchedule(NoiseSchedule):
214225
[1] Diffusion models beat gans on image synthesis: Dhariwal and Nichol (2022)
215226
"""
216227

217-
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
218-
super().__init__(name="cosine_noise_schedule", variance_type=VarianceType.PRESERVING)
228+
def __init__(
229+
self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0, weighting: str = "sigmoid"
230+
):
231+
super().__init__(name="cosine_noise_schedule", variance_type=VarianceType.PRESERVING, weighting=weighting)
219232
self._s_shift_cosine = s_shift_cosine
220233
self._log_snr_min = min_log_snr
221234
self._log_snr_max = max_log_snr
@@ -252,12 +265,6 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
252265
factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
253266
return -factor * dsnr_dt
254267

255-
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
256-
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
257-
Default is the sigmoid weighting based on Kingma et al. (2023).
258-
"""
259-
return ops.sigmoid(-log_snr_t + 2)
260-
261268
def get_config(self):
    """Return the constructor arguments needed to rebuild this schedule.

    ``weighting`` is included because ``__init__`` accepts it as a parameter;
    leaving it out would silently reset a non-default weighting to the
    constructor default when the schedule is rebuilt from its config.
    """
    return dict(
        min_log_snr=self._log_snr_min,
        max_log_snr=self._log_snr_max,
        s_shift_cosine=self._s_shift_cosine,
        weighting=self.weighting,
    )
263270

@@ -345,6 +352,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
345352

346353
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
    """Return the loss weights for the given log signal-to-noise ratio (lambda).

    The weight is ``exp(-lambda) / sigma_data**2 + 1``, evaluated elementwise.
    """
    # for F-prediction: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2)
    inverse_snr = ops.exp(-log_snr_t)
    data_variance = ops.square(self.sigma_data)
    return inverse_snr / data_variance + 1
349357

350358
def get_config(self):
@@ -384,7 +392,7 @@ def __init__(
384392
integrate_kwargs: dict[str, any] = None,
385393
subnet_kwargs: dict[str, any] = None,
386394
noise_schedule: str | NoiseSchedule = "cosine",
387-
prediction_type: str = "velocity",
395+
prediction_type: PredictionType = "velocity",
388396
**kwargs,
389397
):
390398
"""
@@ -431,13 +439,21 @@ def __init__(
431439
# validate noise model
432440
self.noise_schedule.validate()
433441

434-
# Accept enum members as well as their string values: the signature default is
# the plain string "velocity", so coerce to PredictionType before validating.
try:
    prediction_type = PredictionType(prediction_type)
except ValueError:
    raise ValueError(f"Unknown prediction type: {prediction_type}") from None
# Reject anything that is NOT one of the supported parameterizations
# (the check must be `not in`; F is the EDM parameterization).
if prediction_type not in (PredictionType.NOISE, PredictionType.VELOCITY, PredictionType.F):
    raise ValueError(f"Unknown prediction type: {prediction_type}")
self._prediction_type = prediction_type
# Read the name from the resolved schedule object (already validated above via
# self.noise_schedule.validate()); the `noise_schedule` argument itself may be
# a string alias such as "cosine", which has no `.name` attribute.
if self.noise_schedule.name == "edm_noise_schedule" and prediction_type != PredictionType.F:
    warnings.warn(
        "EDM noise schedule is build for F-prediction. Consider using F-prediction instead.",
    )
# The loss parameterization is optional and defaults to the noise-prediction loss.
loss_type = kwargs.get("loss_type", PredictionType.NOISE)
try:
    loss_type = PredictionType(loss_type)
except ValueError:
    raise ValueError(f"Unknown loss type: {loss_type}") from None
self._loss_type = loss_type
if self._loss_type not in (PredictionType.NOISE, PredictionType.VELOCITY, PredictionType.F):
    raise ValueError(f"Unknown loss type: {self._loss_type}")
if self._loss_type != PredictionType.NOISE:
    # The schedules' weighting functions are derived for the noise-prediction loss.
    warnings.warn(
        "the standard schedules have weighting functions defined for the noise prediction loss. "
        "You might want to replace them, if you use a different loss function."
    )
441457

442458
# clipping of prediction (after it was transformed to x-prediction)
443459
self._clip_min = -5.0
@@ -489,7 +505,8 @@ def get_config(self):
489505
"subnet": self.subnet,
490506
"noise_schedule": self.noise_schedule,
491507
"integrate_kwargs": self.integrate_kwargs,
492-
"prediction_type": self.prediction_type,
508+
"prediction_type": self._prediction_type,
509+
"loss_type": self._loss_type,
493510
}
494511
return base_config | serialize(config)
495512

@@ -501,18 +518,18 @@ def convert_prediction_to_x(
501518
self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
502519
) -> Tensor:
503520
"""Convert the prediction of the neural network to the x space."""
504-
if self.prediction_type == "velocity":
521+
if self._prediction_type == PredictionType.VELOCITY:
505522
# convert v into x
506523
x = alpha_t * z - sigma_t * pred
507-
elif self.prediction_type == "noise":
524+
elif self._prediction_type == PredictionType.NOISE:
508525
# convert noise prediction into x
509526
x = (z - sigma_t * pred) / alpha_t
510-
elif self.prediction_type == "F": # EDM
527+
elif self._prediction_type == PredictionType.F: # EDM
511528
sigma_data = self.noise_schedule.sigma_data
512529
x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
513530
x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
514531
x = x1 * z + x2 * pred
515-
elif self.prediction_type == "x":
532+
elif self._prediction_type == PredictionType.X:
516533
x = pred
517534
else: # "score"
518535
x = (z + sigma_t**2 * pred) / alpha_t
@@ -757,10 +774,26 @@ def compute_metrics(
757774
pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t, clip_x=False
758775
)
759776

760-
# convert x to epsilon prediction
761-
noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t
762777
# Calculate loss
763-
loss = weights_for_snr * ops.mean((noise_pred - eps_t) ** 2, axis=-1)
778+
if self._loss_type == PredictionType.NOISE:
779+
# convert x to epsilon prediction
780+
noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t
781+
loss = weights_for_snr * ops.mean((noise_pred - eps_t) ** 2, axis=-1)
782+
elif self._loss_type == PredictionType.VELOCITY:
783+
# convert x to velocity prediction
784+
velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t
785+
v_t = alpha_t * eps_t - sigma_t * x
786+
loss = weights_for_snr * ops.mean((velocity_pred - v_t) ** 2, axis=-1)
787+
elif self._loss_type == PredictionType.F:
788+
# convert x to F prediction
789+
sigma_data = self.noise_schedule.sigma_data
790+
x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)
791+
x2 = (sigma_data * alpha_t) / (ops.exp(-log_snr_t / 2) * ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2))
792+
f_pred = x1 * x_pred - x2 * diffused_x
793+
f_t = x1 * x - x2 * diffused_x
794+
loss = weights_for_snr * ops.mean((f_pred - f_t) ** 2, axis=-1)
795+
else:
796+
raise ValueError(f"Unknown loss type: {self._loss_type}")
764797

765798
# apply sample weight
766799
loss = weighted_mean(loss, sample_weight)

0 commit comments

Comments
 (0)