
Commit 5ca609f

scale snr
1 parent 7c527a5 commit 5ca609f

File tree

1 file changed: +48 −39 lines


bayesflow/experimental/diffusion_model.py

Lines changed: 48 additions & 39 deletions
@@ -41,10 +41,10 @@ class NoiseSchedule(ABC):
 
     def __init__(self, name: str, variance_type: str, weighting: str = None):
         self.name = name
-        self.variance_type = variance_type  # 'exploding' or 'preserving'
-        self._log_snr_min = -15  # should be set in the subclasses
-        self._log_snr_max = 15  # should be set in the subclasses
-        self.weighting = weighting
+        self._variance_type = variance_type  # 'exploding' or 'preserving'
+        self.log_snr_min = -15  # should be set in the subclasses
+        self.log_snr_max = 15  # should be set in the subclasses
+        self._weighting = weighting
 
     @abstractmethod
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
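The visibility flip is the point of this hunk: `log_snr_min` / `log_snr_max` become public because the model itself now reads them off the schedule (see `_transform_log_snr` further down), while `variance_type` and `weighting` become internal. A minimal sketch of what callers can rely on after the commit, assuming the schedule classes are importable from this module:

    from bayesflow.experimental.diffusion_model import LinearNoiseSchedule

    schedule = LinearNoiseSchedule(min_log_snr=-15, max_log_snr=15)

    # public after this commit; read e.g. by the model's _transform_log_snr
    bounds = (schedule.log_snr_min, schedule.log_snr_max)  # (-15, 15)

    # internal from now on: schedule._variance_type, schedule._weighting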
@@ -76,12 +76,12 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool
         beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
         if x is None:  # return g^2 only
             return beta
-        if self.variance_type == "preserving":
+        if self._variance_type == "preserving":
             f = -0.5 * beta * x
-        elif self.variance_type == "exploding":
+        elif self._variance_type == "exploding":
             f = ops.zeros_like(beta)
         else:
-            raise ValueError(f"Unknown variance type: {self.variance_type}")
+            raise ValueError(f"Unknown variance type: {self._variance_type}")
         return f, beta
 
     def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
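As the branches show, `get_drift_diffusion` returns the SDE drift f and squared diffusion g² = β: f = -½·β·x for a variance-preserving schedule and f = 0 for a variance-exploding one. A self-contained numpy sketch of the same branching (β and x are made-up values; numpy stands in for `keras.ops`):

    import numpy as np

    def drift_diffusion(beta, x=None, variance_type="preserving"):
        # mirrors NoiseSchedule.get_drift_diffusion with beta precomputed
        if x is None:  # caller only wants g^2
            return beta
        if variance_type == "preserving":
            f = -0.5 * beta * x
        elif variance_type == "exploding":
            f = np.zeros_like(beta)
        else:
            raise ValueError(f"Unknown variance type: {variance_type}")
        return f, beta

    beta = np.array([0.1, 0.5])
    x = np.array([2.0, -1.0])
    f, g2 = drift_diffusion(beta, x)  # f == [-0.1, 0.25], g2 == beta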
@@ -92,58 +92,58 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         sigma(t) = sqrt(sigmoid(-log_snr_t))
         For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
         """
-        if self.variance_type == "preserving":
+        if self._variance_type == "preserving":
             # variance preserving schedule
             alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
             sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
-        elif self.variance_type == "exploding":
+        elif self._variance_type == "exploding":
             # variance exploding schedule
             alpha_t = ops.ones_like(log_snr_t)
             sigma_t = ops.sqrt(ops.exp(-log_snr_t))
         else:
-            raise ValueError(f"Unknown variance type: {self.variance_type}")
+            raise ValueError(f"Unknown variance type: {self._variance_type}")
         return alpha_t, sigma_t
 
     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). Default is 1.
         Generally, weighting functions should be defined for a noise prediction loss.
         """
-        if self.weighting is None:
+        if self._weighting is None:
             return ops.ones_like(log_snr_t)
-        elif self.weighting == "sigmoid":
+        elif self._weighting == "sigmoid":
             # sigmoid weighting based on Kingma et al. (2023)
             return ops.sigmoid(-log_snr_t + 2)
-        elif self.weighting == "likelihood_weighting":
+        elif self._weighting == "likelihood_weighting":
             # likelihood weighting based on Song et al. (2021)
             g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
             sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
             return g_squared / ops.square(sigma_t)
         else:
-            raise ValueError(f"Unknown weighting type: {self.weighting}")
+            raise ValueError(f"Unknown weighting type: {self._weighting}")
 
     def get_config(self):
-        return dict(name=self.name, variance_type=self.variance_type)
+        return dict(name=self.name, variance_type=self._variance_type)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))
 
     def validate(self):
         """Validate the noise schedule."""
-        if self._log_snr_min >= self._log_snr_max:
+        if self.log_snr_min >= self.log_snr_max:
             raise ValueError("min_log_snr must be less than max_log_snr.")
         for training in [True, False]:
             if not ops.isfinite(self.get_log_snr(0.0, training=training)):
                 raise ValueError("log_snr(0) must be finite.")
             if not ops.isfinite(self.get_log_snr(1.0, training=training)):
                 raise ValueError("log_snr(1) must be finite.")
-            if not ops.isfinite(self.get_t_from_log_snr(self._log_snr_max, training=training)):
+            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
                 raise ValueError("t(0) must be finite.")
-            if not ops.isfinite(self.get_t_from_log_snr(self._log_snr_min, training=training)):
+            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
                 raise ValueError("t(1) must be finite.")
-            if not ops.isfinite(self.derivative_log_snr(self._log_snr_max, training=False)):
+            if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)):
                 raise ValueError("dt/t log_snr(0) must be finite.")
-            if not ops.isfinite(self.derivative_log_snr(self._log_snr_min, training=False)):
+            if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)):
                 raise ValueError("dt/t log_snr(1) must be finite.")
 
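The renamed attributes leave the math untouched, and the docstring formulas are easy to check numerically: in the variance-preserving case, α² + σ² = sigmoid(λ) + sigmoid(-λ) = 1 for every log-SNR λ. A quick numpy sanity check (again standing in for `keras.ops`):

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    log_snr = np.linspace(-15.0, 15.0, 7)

    # variance preserving: alpha^2 + sigma^2 == 1 at every lambda
    alpha = np.sqrt(sigmoid(log_snr))
    sigma = np.sqrt(sigmoid(-log_snr))
    assert np.allclose(alpha**2 + sigma**2, 1.0)

    # variance exploding: alpha == 1, sigma^2 == exp(-lambda)
    sigma_ve = np.sqrt(np.exp(-log_snr))

    # sigmoid weighting (Kingma et al., 2023): down-weights high-SNR targets
    weights = sigmoid(-log_snr + 2.0)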

@@ -158,11 +158,11 @@ class LinearNoiseSchedule(NoiseSchedule):
 
     def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
         super().__init__(name="linear_noise_schedule", variance_type="preserving", weighting="likelihood_weighting")
-        self._log_snr_min = min_log_snr
-        self._log_snr_max = max_log_snr
+        self.log_snr_min = min_log_snr
+        self.log_snr_max = max_log_snr
 
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
 
     def _truncated_t(self, t: Tensor) -> Tensor:
         return self._t_min + (self._t_max - self._t_min) * t
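`_truncated_t` is an affine map from the unit interval onto [t_min, t_max], so training never queries times whose log-SNR would leave the finite bounds; note that t_min corresponds to `log_snr_max` (the clean end) and t_max to `log_snr_min` (the noisy end). A worked miniature with hypothetical bound values:

    # hypothetical endpoints: suppose get_t_from_log_snr yields these
    t_min, t_max = 0.001, 0.999

    def truncated_t(t):
        # same affine map as LinearNoiseSchedule._truncated_t
        return t_min + (t_max - t_min) * t

    print(truncated_t(0.0))  # 0.001 -> the log_snr_max (clean) end
    print(truncated_t(1.0))  # 0.999 -> the log_snr_min (noisy) end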
@@ -194,7 +194,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         return -factor * dsnr_dt
 
     def get_config(self):
-        return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max)
+        return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
@@ -214,12 +214,11 @@ def __init__(
     ):
         super().__init__(name="cosine_noise_schedule", variance_type="preserving", weighting=weighting)
         self._s_shift_cosine = s_shift_cosine
-        self._log_snr_min = min_log_snr
-        self._log_snr_max = max_log_snr
-        self._s_shift_cosine = s_shift_cosine
+        self.log_snr_min = min_log_snr
+        self.log_snr_max = max_log_snr
 
-        self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
-        self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_min, training=True)
+        self._t_min = self.get_t_from_log_snr(log_snr_t=self.log_snr_max, training=True)
+        self._t_max = self.get_t_from_log_snr(log_snr_t=self.log_snr_min, training=True)
 
     def _truncated_t(self, t: Tensor) -> Tensor:
         return self._t_min + (self._t_max - self._t_min) * t
@@ -250,7 +249,7 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         return -factor * dsnr_dt
 
     def get_config(self):
-        return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max, s_shift_cosine=self._s_shift_cosine)
+        return dict(min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, s_shift_cosine=self._s_shift_cosine)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
@@ -278,12 +277,12 @@ def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max:
         self.rho = 7
 
         # convert EDM parameters to signal-to-noise ratio formulation
-        self._log_snr_min = -2 * ops.log(sigma_max)
-        self._log_snr_max = -2 * ops.log(sigma_min)
+        self.log_snr_min = -2 * ops.log(sigma_max)
+        self.log_snr_max = -2 * ops.log(sigma_min)
         # t is not truncated for EDM by definition of the sampling schedule
         # training bounds should be set to avoid numerical issues
-        self._log_snr_min_training = self._log_snr_min - 1  # one is never sampled during training
-        self._log_snr_max_training = self._log_snr_max + 1  # 0 is almost surely never sampled during training
+        self._log_snr_min_training = self.log_snr_min - 1  # one is never sampled during training
+        self._log_snr_max_training = self.log_snr_max + 1  # 0 is almost surely never sampled during training
 
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -537,9 +536,9 @@ def velocity(
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
 
         if conditions is None:
-            xtc = ops.concatenate([xz, log_snr_t], axis=-1)
+            xtc = ops.concatenate([xz, self._transform_log_snr(log_snr_t)], axis=-1)
         else:
-            xtc = ops.concatenate([xz, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
@@ -587,6 +586,16 @@ def f(x):
 
         return v, ops.expand_dims(trace, axis=-1)
 
+    def _transform_log_snr(self, log_snr: Tensor) -> Tensor:
+        """Transform the log_snr to the range [-1, 1] for the diffusion process."""
+        # Transform the log_snr to the range [-1, 1]
+        return (
+            2
+            * (log_snr - self.noise_schedule.log_snr_min)
+            / (self.noise_schedule.log_snr_max - self.noise_schedule.log_snr_min)
+            - 1
+        )
+
     def _forward(
         self,
         x: Tensor,
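This new method is the "scale snr" of the commit title: a min-max rescale of λ into [-1, 1] before it is concatenated into the subnet input. A standalone check of the same affine map with the default ±15 bounds:

    def transform_log_snr(log_snr, lo=-15.0, hi=15.0):
        # same affine map as the new _transform_log_snr above
        return 2 * (log_snr - lo) / (hi - lo) - 1

    assert transform_log_snr(-15.0) == -1.0  # lowest SNR -> -1
    assert transform_log_snr(0.0) == 0.0     # midpoint -> 0
    assert transform_log_snr(15.0) == 1.0    # highest SNR -> +1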
@@ -749,9 +758,9 @@ def compute_metrics(
 
         # calculate output of the network
         if conditions is None:
-            xtc = ops.concatenate([diffused_x, log_snr_t], axis=-1)
+            xtc = ops.concatenate([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
         else:
-            xtc = ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
+            xtc = ops.concatenate([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
         pred = self.output_projector(self.subnet(xtc, training=training), training=training)
 
         x_pred = self.convert_prediction_to_x(
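Both call sites (`velocity` above and `compute_metrics` here) now build the subnet input the same way, so the network's time conditioning stays in [-1, 1] for any schedule bounds. A compressed sketch of the shared pattern, with numpy standing in for `keras.ops` and the default ±15 bounds assumed:

    import numpy as np

    def build_subnet_input(diffused_x, log_snr_t, conditions=None, lo=-15.0, hi=15.0):
        """Sketch of the pattern shared by velocity() and compute_metrics()."""
        snr_cond = 2 * (log_snr_t - lo) / (hi - lo) - 1  # _transform_log_snr
        parts = [diffused_x, snr_cond]
        if conditions is not None:
            parts.append(conditions)
        return np.concatenate(parts, axis=-1)

    x = np.zeros((4, 3))
    lam = np.full((4, 1), 2.5)
    print(build_subnet_input(x, lam).shape)  # (4, 4)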
