Skip to content

Commit 7c527a5

Browse files
committed
add loss types
1 parent 6794342 commit 7c527a5

File tree

1 file changed

+21
-37
lines changed

1 file changed

+21
-37
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import keras
55
from keras import ops
66
import warnings
7-
from enum import Enum
87

98
from bayesflow.utils.serialization import serialize, deserialize, serializable
109
from bayesflow.types import Tensor, Shape
@@ -22,19 +21,6 @@
2221
)
2322

2423

25-
class VarianceType(Enum):
26-
PRESERVING = "preserving"
27-
EXPLODING = "exploding"
28-
29-
30-
class PredictionType(Enum):
31-
VELOCITY = "velocity"
32-
NOISE = "noise"
33-
X = "x"
34-
F = "F" # EDM
35-
SCORE = "score"
36-
37-
3824
@serializable
3925
class NoiseSchedule(ABC):
4026
r"""Noise schedule for diffusion models. We follow the notation from [1].
@@ -53,7 +39,7 @@ class NoiseSchedule(ABC):
5339
Augmentation: Kingma et al. (2023)
5440
"""
5541

56-
def __init__(self, name: str, variance_type: VarianceType, weighting: str = None):
42+
def __init__(self, name: str, variance_type: str, weighting: str = None):
5743
self.name = name
5844
self.variance_type = variance_type # 'exploding' or 'preserving'
5945
self._log_snr_min = -15 # should be set in the subclasses
@@ -90,9 +76,9 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
9076
beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
9177
if x is None: # return g^2 only
9278
return beta
93-
if self.variance_type == VarianceType.PRESERVING:
79+
if self.variance_type == "preserving":
9480
f = -0.5 * beta * x
95-
elif self.variance_type == VarianceType.EXPLODING:
81+
elif self.variance_type == "exploding":
9682
f = ops.zeros_like(beta)
9783
else:
9884
raise ValueError(f"Unknown variance type: {self.variance_type}")
@@ -106,11 +92,11 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Te
10692
sigma(t) = sqrt(sigmoid(-log_snr_t))
10793
For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
10894
"""
109-
if self.variance_type == VarianceType.PRESERVING:
95+
if self.variance_type == "preserving":
11096
# variance preserving schedule
11197
alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
11298
sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
113-
elif self.variance_type == VarianceType.EXPLODING:
99+
elif self.variance_type == "exploding":
114100
# variance exploding schedule
115101
alpha_t = ops.ones_like(log_snr_t)
116102
sigma_t = ops.sqrt(ops.exp(-log_snr_t))
@@ -171,9 +157,7 @@ class LinearNoiseSchedule(NoiseSchedule):
171157
"""
172158

173159
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
174-
super().__init__(
175-
name="linear_noise_schedule", variance_type=VarianceType.PRESERVING, weighting="likelihood_weighting"
176-
)
160+
super().__init__(name="linear_noise_schedule", variance_type="preserving", weighting="likelihood_weighting")
177161
self._log_snr_min = min_log_snr
178162
self._log_snr_max = max_log_snr
179163

@@ -228,7 +212,7 @@ class CosineNoiseSchedule(NoiseSchedule):
228212
def __init__(
229213
self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0, weighting: str = "sigmoid"
230214
):
231-
super().__init__(name="cosine_noise_schedule", variance_type=VarianceType.PRESERVING, weighting=weighting)
215+
super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
232216
self._s_shift_cosine = s_shift_cosine
233217
self._log_snr_min = min_log_snr
234218
self._log_snr_max = max_log_snr
@@ -283,7 +267,7 @@ class EDMNoiseSchedule(NoiseSchedule):
283267
"""
284268

285269
def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0):
286-
super().__init__(name="edm_noise_schedule", variance_type=VarianceType.PRESERVING)
270+
super().__init__(name="edm_noise_schedule", variance_type="preserving")
287271
self.sigma_data = sigma_data
288272
# training settings
289273
self.p_mean = -1.2
@@ -392,7 +376,7 @@ def __init__(
392376
integrate_kwargs: dict[str, any] = None,
393377
subnet_kwargs: dict[str, any] = None,
394378
noise_schedule: str | NoiseSchedule = "cosine",
395-
prediction_type: PredictionType = "velocity",
379+
prediction_type: str = "velocity",
396380
**kwargs,
397381
):
398382
"""
@@ -439,17 +423,17 @@ def __init__(
439423
# validate noise model
440424
self.noise_schedule.validate()
441425

442-
if prediction_type in [PredictionType.NOISE, PredictionType.VELOCITY, PredictionType.F]: # F is EDM
426+
if prediction_type not in ["noise", "velocity", "F"]: # F is EDM
443427
raise ValueError(f"Unknown prediction type: {prediction_type}")
444428
self._prediction_type = prediction_type
445-
if noise_schedule.name == "edm_noise_schedule" and prediction_type != PredictionType.F:
429+
if noise_schedule.name == "edm_noise_schedule" and prediction_type != "F":
446430
warnings.warn(
447431
"EDM noise schedule is build for F-prediction. Consider using F-prediction instead.",
448432
)
449-
self._loss_type = kwargs.get("loss_type", PredictionType.NOISE)
450-
if self._loss_type not in [PredictionType.NOISE, PredictionType.VELOCITY, PredictionType.F]:
433+
self._loss_type = kwargs.get("loss_type", "noise")
434+
if self._loss_type not in ["noise", "velocity", "F"]:
451435
raise ValueError(f"Unknown loss type: {self._loss_type}")
452-
if self._loss_type != PredictionType.NOISE:
436+
if self._loss_type != "noise":
453437
warnings.warn(
454438
"the standard schedules have weighting functions defined for the noise prediction loss. "
455439
"You might want to replace them, if you use a different loss function."
@@ -518,18 +502,18 @@ def convert_prediction_to_x(
518502
self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
519503
) -> Tensor:
520504
"""Convert the prediction of the neural network to the x space."""
521-
if self._prediction_type == PredictionType.VELOCITY:
505+
if self._prediction_type == "velocity":
522506
# convert v into x
523507
x = alpha_t * z - sigma_t * pred
524-
elif self._prediction_type == PredictionType.NOISE:
508+
elif self._prediction_type == "noise":
525509
# convert noise prediction into x
526510
x = (z - sigma_t * pred) / alpha_t
527-
elif self._prediction_type == PredictionType.F: # EDM
511+
elif self._prediction_type == "F": # EDM
528512
sigma_data = self.noise_schedule.sigma_data
529513
x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
530514
x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
531515
x = x1 * z + x2 * pred
532-
elif self._prediction_type == PredictionType.X:
516+
elif self._prediction_type == "x":
533517
x = pred
534518
else: # "score"
535519
x = (z + sigma_t**2 * pred) / alpha_t
@@ -775,16 +759,16 @@ def compute_metrics(
775759
)
776760

777761
# Calculate loss
778-
if self._loss_type == PredictionType.NOISE:
762+
if self._loss_type == "noise":
779763
# convert x to epsilon prediction
780764
noise_pred = (diffused_x - alpha_t * x_pred) / sigma_t
781765
loss = weights_for_snr * ops.mean((noise_pred - eps_t) ** 2, axis=-1)
782-
elif self._loss_type == PredictionType.VELOCITY:
766+
elif self._loss_type == "velocity":
783767
# convert x to velocity prediction
784768
velocity_pred = (alpha_t * diffused_x - x_pred) / sigma_t
785769
v_t = alpha_t * eps_t - sigma_t * x
786770
loss = weights_for_snr * ops.mean((velocity_pred - v_t) ** 2, axis=-1)
787-
elif self._loss_type == PredictionType.F:
771+
elif self._loss_type == "F":
788772
# convert x to F prediction
789773
sigma_data = self.noise_schedule.sigma_data
790774
x1 = ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2) / (ops.exp(-log_snr_t / 2) * sigma_data)

0 commit comments

Comments
 (0)