Skip to content

Commit d5dc2ba

Browse files
committed
wip: adapt network to layer paradigm
1 parent 280b651 commit d5dc2ba

File tree

1 file changed

+49
-36
lines changed

1 file changed

+49
-36
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from abc import ABC, abstractmethod
33
import keras
44
from keras import ops
5-
from keras.saving import register_keras_serializable as serializable
65

6+
from bayesflow.utils.serialization import serialize, deserialize, serializable
77
from bayesflow.types import Tensor, Shape
88
import bayesflow as bf
99
from bayesflow.networks import InferenceNetwork
@@ -13,9 +13,7 @@
1313
expand_right_as,
1414
find_network,
1515
jacobian_trace,
16-
keras_kwargs,
17-
serialize_value_or_type,
18-
deserialize_value_or_type,
16+
layer_kwargs,
1917
weighted_mean,
2018
integrate,
2119
)
@@ -145,8 +143,8 @@ class LinearNoiseSchedule(NoiseSchedule):
145143

146144
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
147145
super().__init__(name="linear_noise_schedule")
148-
self._log_snr_min = ops.convert_to_tensor(min_log_snr)
149-
self._log_snr_max = ops.convert_to_tensor(max_log_snr)
146+
self._log_snr_min = min_log_snr
147+
self._log_snr_max = max_log_snr
150148

151149
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
152150
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
@@ -192,11 +190,11 @@ class CosineNoiseSchedule(NoiseSchedule):
192190
[1] Diffusion models beat gans on image synthesis: Dhariwal and Nichol (2022)
193191
"""
194192

195-
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
193+
def __init__(self, min_log_snr: float = -15.0, max_log_snr: float = 15.0, s_shift_cosine: float = 0.0):
196194
super().__init__(name="cosine_noise_schedule")
197-
self._log_snr_min = ops.convert_to_tensor(min_log_snr)
198-
self._log_snr_max = ops.convert_to_tensor(max_log_snr)
199-
self._s_shift_cosine = ops.convert_to_tensor(s_shift_cosine)
195+
self._log_snr_min = min_log_snr
196+
self._log_snr_max = max_log_snr
197+
self._s_shift_cosine = s_shift_cosine
200198

201199
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
202200
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
@@ -210,7 +208,8 @@ def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
210208
def get_t_from_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
211209
"""Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
212210
# SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
213-
return 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) / 2))
211+
print("p", log_snr_t)
212+
return 2.0 / math.pi * ops.arctan(ops.exp((2.0 * self._s_shift_cosine - log_snr_t) / 2.0))
214213

215214
def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
216215
"""Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
@@ -241,12 +240,12 @@ class EDMNoiseSchedule(NoiseSchedule):
241240

242241
def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
243242
super().__init__(name="edm_noise_schedule")
244-
self.sigma_data = ops.convert_to_tensor(sigma_data)
245-
self.sigma_max = ops.convert_to_tensor(sigma_max)
246-
self.sigma_min = ops.convert_to_tensor(sigma_min)
247-
self.p_mean = ops.convert_to_tensor(-1.2)
248-
self.p_std = ops.convert_to_tensor(1.2)
249-
self.rho = ops.convert_to_tensor(7)
243+
self.sigma_data = sigma_data
244+
self.sigma_max = sigma_max
245+
self.sigma_min = sigma_min
246+
self.p_mean = -1.2
247+
self.p_std = 1.2
248+
self.rho = 7
250249

251250
# convert EDM parameters to signal-to-noise ratio formulation
252251
self._log_snr_min = -2 * ops.log(sigma_max)
@@ -336,7 +335,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
336335
return ops.exp(-log_snr_t) + 0.5**2
337336

338337

339-
@serializable(package="bayesflow.networks")
338+
@serializable
340339
class DiffusionModel(InferenceNetwork):
341340
"""Diffusion Model as described in this overview paper [1].
342341
@@ -395,7 +394,7 @@ def __init__(
395394
Additional keyword arguments passed to the subnet and other components.
396395
"""
397396

398-
super().__init__(base_distribution=None, **keras_kwargs(kwargs))
397+
super().__init__(base_distribution=None, **kwargs)
399398

400399
if isinstance(noise_schedule, str):
401400
if noise_schedule == "linear":
@@ -432,18 +431,11 @@ def __init__(
432431
self.subnet = find_network(subnet, **subnet_kwargs)
433432
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
434433

435-
# serialization: store all parameters necessary to call __init__
436-
self.config = {
437-
"integrate_kwargs": self.integrate_kwargs,
438-
"subnet_kwargs": subnet_kwargs,
439-
"noise_schedule": self.noise_schedule,
440-
"prediction_type": self.prediction_type,
441-
**kwargs,
442-
}
443-
self.config = serialize_value_or_type(self.config, "subnet", subnet)
444-
445434
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
446-
super().build(xz_shape, conditions_shape=conditions_shape)
435+
if self.built:
436+
return
437+
438+
self.base_distribution.build(xz_shape)
447439

448440
self.output_projector.units = xz_shape[-1]
449441
input_shape = list(xz_shape)
@@ -461,12 +453,19 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
461453

462454
def get_config(self):
463455
base_config = super().get_config()
464-
return base_config | self.config
456+
base_config = layer_kwargs(base_config)
457+
458+
config = {
459+
"subnet": self.subnet,
460+
"noise_schedule": self.noise_schedule,
461+
"integrate_kwargs": self.integrate_kwargs,
462+
"prediction_type": self.prediction_type,
463+
}
464+
return base_config | serialize(config)
465465

466466
@classmethod
467-
def from_config(cls, config):
468-
config = deserialize_value_or_type(config, "subnet")
469-
return cls(**config)
467+
def from_config(cls, config, custom_objects=None):
468+
return cls(**deserialize(config, custom_objects=custom_objects))
470469

471470
def convert_prediction_to_x(
472471
self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
@@ -546,7 +545,14 @@ def _forward(
546545
training: bool = False,
547546
**kwargs,
548547
) -> Tensor | tuple[Tensor, Tensor]:
549-
integrate_kwargs = self.integrate_kwargs | kwargs
548+
integrate_kwargs = (
549+
{
550+
"start_time": self.noise_schedule._t_min,
551+
"stop_time": self.noise_schedule._t_max,
552+
}
553+
| self.integrate_kwargs
554+
| kwargs
555+
)
550556
if density:
551557

552558
def deltas(time, xz):
@@ -588,7 +594,14 @@ def _inverse(
588594
training: bool = False,
589595
**kwargs,
590596
) -> Tensor | tuple[Tensor, Tensor]:
591-
integrate_kwargs = self.integrate_kwargs | kwargs
597+
integrate_kwargs = (
598+
{
599+
"start_time": self.noise_schedule._t_max,
600+
"stop_time": self.noise_schedule._t_min,
601+
}
602+
| self.integrate_kwargs
603+
| kwargs
604+
)
592605
if density:
593606

594607
def deltas(time, xz):

0 commit comments

Comments (0)