Skip to content

Commit 92131d7

Browse files
committed
add serialization, remove unnecessary tensor conversions
1 parent efeff85 commit 92131d7

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
)
2020

2121

22+
@serializable
2223
class NoiseSchedule(ABC):
23-
"""Noise schedule for diffusion models. We follow the notation from [1].
24+
r"""Noise schedule for diffusion models. We follow the notation from [1].
2425
2526
The diffusion process is defined by a noise schedule, which determines how the noise level changes over time.
2627
We define the noise schedule as a function of the log signal-to-noise ratio (lambda), which can be
@@ -39,8 +40,8 @@ class NoiseSchedule(ABC):
3940
def __init__(self, name: str, variance_type: str):
4041
self.name = name
4142
self.variance_type = variance_type # 'exploding' or 'preserving'
42-
self._log_snr_min = ops.convert_to_tensor(-15) # should be set in the subclasses
43-
self._log_snr_max = ops.convert_to_tensor(15) # should be set in the subclasses
43+
self._log_snr_min = -15 # should be set in the subclasses
44+
self._log_snr_max = 15 # should be set in the subclasses
4445

4546
@property
4647
def scale_base_distribution(self):
@@ -65,11 +66,11 @@ def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
6566

6667
@abstractmethod
6768
def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
68-
"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
69+
r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
6970
pass
7071

7172
def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = True) -> tuple[Tensor, Tensor]:
72-
"""Compute the drift and optionally the diffusion term for the reverse SDE.
73+
r"""Compute the drift and optionally the diffusion term for the reverse SDE.
7374
Usually it can be derived from the derivative of the schedule:
7475
\beta(t) = d/dt log(1 + e^(-snr(t)))
7576
f(z, t) = -0.5 * \beta(t) * z
@@ -121,7 +122,15 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
121122
# 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
122123
return ops.ones_like(log_snr_t)
123124

125+
def get_config(self):
126+
return dict(name=self.name, variance_type=self.variance_type)
127+
128+
@classmethod
129+
def from_config(cls, config, custom_objects=None):
130+
return cls(**deserialize(config, custom_objects=custom_objects))
131+
124132

133+
@serializable
125134
class LinearNoiseSchedule(NoiseSchedule):
126135
"""Linear noise schedule for diffusion models.
127136
@@ -171,7 +180,15 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
171180
sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
172181
return ops.square(g / sigma_t)
173182

183+
def get_config(self):
184+
return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max)
174185

186+
@classmethod
187+
def from_config(cls, config, custom_objects=None):
188+
return cls(**deserialize(config, custom_objects=custom_objects))
189+
190+
191+
@serializable
175192
class CosineNoiseSchedule(NoiseSchedule):
176193
"""Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].
177194
For images, use s_shift_cosine = log(base_resolution / d), where d is the used resolution of the image.
@@ -181,7 +198,7 @@ class CosineNoiseSchedule(NoiseSchedule):
181198

182199
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
183200
super().__init__(name="cosine_noise_schedule", variance_type="preserving")
184-
self._s_shift_cosine = ops.convert_to_tensor(s_shift_cosine)
201+
self._s_shift_cosine = s_shift_cosine
185202
self._log_snr_min = min_log_snr
186203
self._log_snr_max = max_log_snr
187204
self._s_shift_cosine = s_shift_cosine
@@ -220,7 +237,15 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
220237
"""
221238
return ops.sigmoid(-log_snr_t / 2)
222239

240+
def get_config(self):
241+
return dict(min_log_snr=self._log_snr_min, max_log_snr=self._log_snr_max, s_shift_cosine=self._s_shift_cosine)
223242

243+
@classmethod
244+
def from_config(cls, config, custom_objects=None):
245+
return cls(**deserialize(config, custom_objects=custom_objects))
246+
247+
248+
@serializable
224249
class EDMNoiseSchedule(NoiseSchedule):
225250
"""EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
226251
@@ -472,6 +497,7 @@ def velocity(
472497
) -> Tensor:
473498
# calculate the current noise level and transform into correct shape
474499
log_snr_t = expand_right_as(self.noise_schedule.get_log_snr(t=time, training=training), xz)
500+
log_snr_t = keras.ops.broadcast_to(log_snr_t, keras.ops.shape(xz)[:-1] + (1,))
475501
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t, training=training)
476502

477503
if conditions is None:

0 commit comments

Comments
 (0)