
Commit 5b52499

minor changes

1 parent 196683c commit 5b52499

File tree

1 file changed: +26 −26 lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 26 additions & 26 deletions
@@ -75,7 +75,7 @@ def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) ->
 
     def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
         r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
-        Usually it can be derived from the derivative of the schedule:
+        It can be derived from the derivative of the schedule:
         \beta(t) = d/dt log(1 + e^(-snr(t)))
         f(z, t) = -0.5 * \beta(t) * z
         g(t)^2 = \beta(t)
@@ -85,9 +85,8 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
 
         For a variance exploding schedule, one should set f(z, t) = 0.
         """
-        # Default implementation is to return the diffusion term only
         beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
-        if x is None:  # return g only
+        if x is None:  # return g^2 only
             return beta
         if self.variance_type == "preserving":
             f = -0.5 * beta * x
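
The relations in this docstring can be sanity-checked numerically: beta(t) follows from differentiating log(1 + e^(-snr(t))) for any schedule. A minimal numpy sketch, not part of this commit; the linear log-SNR schedule and the finite-difference step are illustrative assumptions:

import numpy as np

def log_snr(t):
    # toy variance-preserving schedule (assumption): log-SNR decreases linearly in t
    return 10.0 - 20.0 * t

def beta(t, eps=1e-4):
    # beta(t) = d/dt log(1 + exp(-log_snr(t))), via a central difference
    f = lambda s: np.log1p(np.exp(-log_snr(s)))
    return (f(t + eps) - f(t - eps)) / (2.0 * eps)

t = 0.5
z = np.array([0.3, -1.2])
drift = -0.5 * beta(t) * z   # f(z, t) in the variance-preserving case
sq_diff = beta(t)            # g(t)^2; a variance-exploding schedule would use drift = 0
print(drift, sq_diff)
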
@@ -121,7 +120,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). Default is 1.
         Generally, weighting functions should be defined for a noise prediction loss.
         """
-        # sigmoid: ops.sigmoid(-log_snr_t / 2), based on Kingma et al. (2023)
+        # sigmoid: ops.sigmoid(-log_snr_t + 2), based on Kingma et al. (2023)
         # min-snr with gamma = 5, based on Hang et al. (2023)
         # 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
         return ops.ones_like(log_snr_t)
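
The two commented-out alternatives are easy to express as standalone functions. A hedged numpy sketch of the sigmoid and min-SNR weights exactly as written in the comments above (gamma = 5 per the comment; consult Kingma et al. (2023) and Hang et al. (2023) for the canonical forms):

import numpy as np

def sigmoid_weight(log_snr_t):
    # sigmoid weighting from the comment: sigmoid(-log_snr_t + 2)
    return 1.0 / (1.0 + np.exp(log_snr_t - 2.0))

def min_snr_weight(log_snr_t, gamma=5.0):
    # min-SNR weighting from the comment: sech-scaled, capped by gamma * exp(-log_snr_t)
    return 1.0 / np.cosh(log_snr_t / 2.0) * np.minimum(1.0, gamma * np.exp(-log_snr_t))

lam = np.linspace(-10.0, 10.0, 5)
print(sigmoid_weight(lam))
print(min_snr_weight(lam))
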
@@ -291,9 +290,9 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
         self._log_snr_min = -2 * ops.log(sigma_max)
         self._log_snr_max = -2 * ops.log(sigma_min)
         # t is not truncated for EDM by definition of the sampling schedule
-        # training bounds are not so important, but should be set to avoid numerical issues
-        self._log_snr_min_training = self._log_snr_min * 2  # one is never sampled during training
-        self._log_snr_max_training = self._log_snr_max * 2  # 0 is almost surely never sampled during training
+        # training bounds should be set to avoid numerical issues
+        self._log_snr_min_training = self._log_snr_min - 1  # one is never sampled during training
+        self._log_snr_max_training = self._log_snr_max + 1  # 0 is almost surely never sampled during training
 
     def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
@@ -304,14 +303,9 @@ def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
             snr = -(loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2))
             snr = keras.ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
         else:  # sampling
-            snr = (
-                -2
-                * self.rho
-                * ops.log(
-                    self.sigma_max ** (1 / self.rho)
-                    + (1 - t) * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
-                )
-            )
+            sigma_min_rho = self.sigma_min ** (1 / self.rho)
+            sigma_max_rho = self.sigma_max ** (1 / self.rho)
+            snr = -2 * self.rho * ops.log(sigma_max_rho + (1 - t) * (sigma_min_rho - sigma_max_rho))
         return snr
 
     def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
@@ -325,10 +319,9 @@ def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) ->
         else:  # sampling
             # SNR = -2 * rho * log(sigma_max ** (1/rho) + (1 - t) * (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
             # => t = 1 - ((exp(-snr/(2*rho)) - sigma_max ** (1/rho)) / (sigma_min ** (1/rho) - sigma_max ** (1/rho)))
-            t = 1 - (
-                (ops.exp(-log_snr_t / (2 * self.rho)) - self.sigma_max ** (1 / self.rho))
-                / (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
-            )
+            sigma_min_rho = self.sigma_min ** (1 / self.rho)
+            sigma_max_rho = self.sigma_max ** (1 / self.rho)
+            t = 1 - ((ops.exp(-log_snr_t / (2 * self.rho)) - sigma_max_rho) / (sigma_min_rho - sigma_max_rho))
         return t
 
     def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
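
Since the sampling branches of get_log_snr and get_t_from_log_snr are mutual inverses, the refactored expressions can be verified with a round trip. A standalone numpy sketch of just those two formulas (the rho, sigma_min, sigma_max values are illustrative assumptions):

import numpy as np

rho, sigma_min, sigma_max = 7.0, 0.002, 80.0
sigma_min_rho = sigma_min ** (1 / rho)
sigma_max_rho = sigma_max ** (1 / rho)

def log_snr(t):
    # sampling branch of get_log_snr
    return -2 * rho * np.log(sigma_max_rho + (1 - t) * (sigma_min_rho - sigma_max_rho))

def t_from_log_snr(snr):
    # sampling branch of get_t_from_log_snr
    return 1 - ((np.exp(-snr / (2 * rho)) - sigma_max_rho) / (sigma_min_rho - sigma_max_rho))

t = np.linspace(0.0, 1.0, 101)
assert np.allclose(t_from_log_snr(log_snr(t)), t)   # round trip recovers t
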
@@ -354,6 +347,13 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
         return (ops.exp(-log_snr_t) + ops.square(self.sigma_data)) / ops.square(self.sigma_data)
 
+    def get_config(self):
+        return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
 
 @serializable
 class DiffusionModel(InferenceNetwork):
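
The new get_config/from_config pair follows the usual Keras serialization contract: get_config returns plain constructor kwargs, and from_config rebuilds the object (here additionally routed through the module's deserialize helper). A self-contained toy illustration of the round trip this enables; ToySchedule is a stand-in, not BayesFlow code:

class ToySchedule:
    def __init__(self, sigma_data=0.5, sigma_min=0.002, sigma_max=80.0):
        self.sigma_data = sigma_data
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def get_config(self):
        # plain constructor kwargs, as in the diff above
        return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max)

    @classmethod
    def from_config(cls, config):
        # the real from_config also passes config through deserialize(...)
        return cls(**config)

restored = ToySchedule.from_config(ToySchedule(sigma_data=1.0).get_config())
assert restored.sigma_data == 1.0
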
@@ -510,15 +510,15 @@ def convert_prediction_to_x(
         elif self.prediction_type == "noise":
             # convert noise prediction into x
             x = (z - sigma_t * pred) / alpha_t
-        elif self.prediction_type == "x":
-            x = pred
-        elif self.prediction_type == "score":
-            x = (z + sigma_t**2 * pred) / alpha_t
-        else:  # self.prediction_type == 'F':  # EDM
+        elif self.prediction_type == "F":  # EDM
             sigma_data = self.noise_schedule.sigma_data
             x1 = (sigma_data**2 * alpha_t) / (ops.exp(-log_snr_t) + sigma_data**2)
             x2 = ops.exp(-log_snr_t / 2) * sigma_data / ops.sqrt(ops.exp(-log_snr_t) + sigma_data**2)
             x = x1 * z + x2 * pred
+        elif self.prediction_type == "x":
+            x = pred
+        else:  # "score"
+            x = (z + sigma_t**2 * pred) / alpha_t
 
         if clip_x:
             x = keras.ops.clip(x, self._clip_min, self._clip_max)
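
Reordering the branches does not change the mathematics: under the standard relations noise = (z - alpha_t * x) / sigma_t and score = -noise / sigma_t, the "noise" and "score" branches recover the same x. A small numpy consistency check (alpha_t, sigma_t, and the sample values are illustrative):

import numpy as np

alpha_t, sigma_t = 0.8, 0.6
x_true = np.array([0.25])
noise = np.array([0.4])
z = alpha_t * x_true + sigma_t * noise   # forward perturbation

score = -noise / sigma_t                 # standard score/noise relation

x_from_noise = (z - sigma_t * noise) / alpha_t      # "noise" branch
x_from_score = (z + sigma_t**2 * score) / alpha_t   # "score" branch
assert np.allclose(x_from_noise, x_true)
assert np.allclose(x_from_score, x_true)
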
@@ -606,7 +606,7 @@ def _forward(
             | kwargs
         )
         if integrate_kwargs["method"] == "euler_maruyama":
-            raise ValueError("Stoachastic methods are not supported for forward integration.")
+            raise ValueError("Stochastic methods are not supported for forward integration.")
 
         if density:
 
@@ -661,7 +661,7 @@ def _inverse(
         )
         if density:
             if integrate_kwargs["method"] == "euler_maruyama":
-                raise ValueError("Stoachastic methods are not supported for density computation.")
+                raise ValueError("Stochastic methods are not supported for density computation.")
 
         def deltas(time, xz):
             v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
