Skip to content

Commit efeff85

Browse files
committed
Merge branch 'feat-diffusion-model-adapt' into feat-diffusion-model
2 parents 739491a + d5dc2ba commit efeff85

File tree

12 files changed

+104
-40
lines changed

12 files changed

+104
-40
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from abc import ABC, abstractmethod
33
import keras
44
from keras import ops
5-
from keras.saving import register_keras_serializable as serializable
65

6+
from bayesflow.utils.serialization import serialize, deserialize, serializable
77
from bayesflow.types import Tensor, Shape
88
import bayesflow as bf
99
from bayesflow.networks import InferenceNetwork
@@ -13,8 +13,7 @@
1313
expand_right_as,
1414
find_network,
1515
jacobian_trace,
16-
serialize_value_or_type,
17-
deserialize_value_or_type,
16+
layer_kwargs,
1817
weighted_mean,
1918
integrate,
2019
)
@@ -132,9 +131,9 @@ class LinearNoiseSchedule(NoiseSchedule):
132131
"""
133132

134133
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
135-
super().__init__(name="linear_noise_schedule", variance_type="preserving")
136-
self._log_snr_min = ops.convert_to_tensor(min_log_snr)
137-
self._log_snr_max = ops.convert_to_tensor(max_log_snr)
134+
super().__init__(name="linear_noise_schedule")
135+
self._log_snr_min = min_log_snr
136+
self._log_snr_max = max_log_snr
138137

139138
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
140139
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
@@ -182,9 +181,10 @@ class CosineNoiseSchedule(NoiseSchedule):
182181

183182
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
184183
super().__init__(name="cosine_noise_schedule", variance_type="preserving")
185-
self._log_snr_min = ops.convert_to_tensor(min_log_snr)
186-
self._log_snr_max = ops.convert_to_tensor(max_log_snr)
187184
self._s_shift_cosine = ops.convert_to_tensor(s_shift_cosine)
185+
self._log_snr_min = min_log_snr
186+
self._log_snr_max = max_log_snr
187+
self._s_shift_cosine = s_shift_cosine
188188

189189
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
190190
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
@@ -229,12 +229,13 @@ class EDMNoiseSchedule(NoiseSchedule):
229229

230230
def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
231231
super().__init__(name="edm_noise_schedule", variance_type="exploding")
232-
self.sigma_data = ops.convert_to_tensor(sigma_data)
233-
self.sigma_max = ops.convert_to_tensor(sigma_max)
234-
self.sigma_min = ops.convert_to_tensor(sigma_min)
235-
self.p_mean = ops.convert_to_tensor(-1.2)
236-
self.p_std = ops.convert_to_tensor(1.2)
237-
self.rho = ops.convert_to_tensor(7)
232+
super().__init__(name="edm_noise_schedule")
233+
self.sigma_data = sigma_data
234+
self.sigma_max = sigma_max
235+
self.sigma_min = sigma_min
236+
self.p_mean = -1.2
237+
self.p_std = 1.2
238+
self.rho = 7
238239

239240
# convert EDM parameters to signal-to-noise ratio formulation
240241
self._log_snr_min = -2 * ops.log(sigma_max)
@@ -306,7 +307,7 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
306307
return ops.exp(-log_snr_t) + 0.5**2
307308

308309

309-
@serializable(package="bayesflow.networks")
310+
@serializable
310311
class DiffusionModel(InferenceNetwork):
311312
"""Diffusion Model as described in this overview paper [1].
312313
@@ -401,18 +402,11 @@ def __init__(
401402
self.subnet = find_network(subnet, **subnet_kwargs)
402403
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")
403404

404-
# serialization: store all parameters necessary to call __init__
405-
self.config = {
406-
"integrate_kwargs": self.integrate_kwargs,
407-
"subnet_kwargs": subnet_kwargs,
408-
"noise_schedule": self.noise_schedule,
409-
"prediction_type": self.prediction_type,
410-
**kwargs,
411-
}
412-
self.config = serialize_value_or_type(self.config, "subnet", subnet)
413-
414405
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
415-
super().build(xz_shape, conditions_shape=conditions_shape)
406+
if self.built:
407+
return
408+
409+
self.base_distribution.build(xz_shape)
416410

417411
self.output_projector.units = xz_shape[-1]
418412
input_shape = list(xz_shape)
@@ -430,12 +424,19 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
430424

431425
def get_config(self):
432426
base_config = super().get_config()
433-
return base_config | self.config
427+
base_config = layer_kwargs(base_config)
428+
429+
config = {
430+
"subnet": self.subnet,
431+
"noise_schedule": self.noise_schedule,
432+
"integrate_kwargs": self.integrate_kwargs,
433+
"prediction_type": self.prediction_type,
434+
}
435+
return base_config | serialize(config)
434436

435437
@classmethod
436-
def from_config(cls, config):
437-
config = deserialize_value_or_type(config, "subnet")
438-
return cls(**config)
438+
def from_config(cls, config, custom_objects=None):
439+
return cls(**deserialize(config, custom_objects=custom_objects))
439440

440441
def convert_prediction_to_x(
441442
self, pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor, clip_x: bool
@@ -515,7 +516,14 @@ def _forward(
515516
training: bool = False,
516517
**kwargs,
517518
) -> Tensor | tuple[Tensor, Tensor]:
518-
integrate_kwargs = self.integrate_kwargs | kwargs
519+
integrate_kwargs = (
520+
{
521+
"start_time": self.noise_schedule._t_min,
522+
"stop_time": self.noise_schedule._t_max,
523+
}
524+
| self.integrate_kwargs
525+
| kwargs
526+
)
519527
if density:
520528

521529
def deltas(time, xz):
@@ -557,7 +565,14 @@ def _inverse(
557565
training: bool = False,
558566
**kwargs,
559567
) -> Tensor | tuple[Tensor, Tensor]:
560-
integrate_kwargs = self.integrate_kwargs | kwargs
568+
integrate_kwargs = (
569+
{
570+
"start_time": self.noise_schedule._t_max,
571+
"stop_time": self.noise_schedule._t_min,
572+
}
573+
| self.integrate_kwargs
574+
| kwargs
575+
)
561576
if density:
562577

563578
def deltas(time, xz):

bayesflow/links/ordered.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from keras.saving import register_keras_serializable as serializable
33

44
from bayesflow.utils import layer_kwargs
5+
from bayesflow.utils.decorators import sanitize_input_shape
56

67

78
@serializable(package="links.ordered")
@@ -49,5 +50,6 @@ def call(self, inputs):
4950
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
5051
return x
5152

53+
@sanitize_input_shape
5254
def compute_output_shape(self, input_shape):
5355
return input_shape

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def build(self, input_shape):
2121
if self.base_distribution is not None:
2222
self.base_distribution.build(keras.ops.shape(z))
2323

24+
@sanitize_input_shape
2425
def compute_output_shape(self, input_shape):
2526
return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))
2627

bayesflow/networks/transformers/mab.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.decorators import sanitize_input_shape
78
from bayesflow.utils.serialization import serializable
89

910

@@ -122,8 +123,10 @@ def call(self, seq_x: Tensor, seq_y: Tensor, training: bool = False, **kwargs) -
122123
return out
123124

124125
# noinspection PyMethodOverriding
126+
@sanitize_input_shape
125127
def build(self, seq_x_shape, seq_y_shape):
126128
self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape))
127129

130+
@sanitize_input_shape
128131
def compute_output_shape(self, seq_x_shape, seq_y_shape):
129132
return keras.ops.shape(self.call(keras.ops.zeros(seq_x_shape), keras.ops.zeros(seq_y_shape)))

bayesflow/networks/transformers/pma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.networks import MLP
55
from bayesflow.types import Tensor
66
from bayesflow.utils import layer_kwargs
7+
from bayesflow.utils.decorators import sanitize_input_shape
78
from bayesflow.utils.serialization import serializable
89

910
from .mab import MultiHeadAttentionBlock
@@ -125,5 +126,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
125126
summaries = self.mab(seed_tiled, set_x_transformed, training=training, **kwargs)
126127
return ops.reshape(summaries, (ops.shape(summaries)[0], -1))
127128

129+
@sanitize_input_shape
128130
def compute_output_shape(self, input_shape):
129131
return keras.ops.shape(self.call(keras.ops.zeros(input_shape)))

bayesflow/networks/transformers/sab.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import keras
22

33
from bayesflow.types import Tensor
4+
from bayesflow.utils.decorators import sanitize_input_shape
45
from bayesflow.utils.serialization import serializable
56

67
from .mab import MultiHeadAttentionBlock
@@ -16,6 +17,7 @@ class SetAttentionBlock(MultiHeadAttentionBlock):
1617
"""
1718

1819
# noinspection PyMethodOverriding
20+
@sanitize_input_shape
1921
def build(self, input_set_shape):
2022
self.call(keras.ops.zeros(input_set_shape))
2123

@@ -42,5 +44,6 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
4244
return super().call(input_set, input_set, training=training, **kwargs)
4345

4446
# noinspection PyMethodOverriding
47+
@sanitize_input_shape
4548
def compute_output_shape(self, input_set_shape):
4649
return keras.ops.shape(self.call(keras.ops.zeros(input_set_shape)))

bayesflow/utils/decorators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def callback(x):
114114

115115

116116
def sanitize_input_shape(fn: Callable):
117-
"""Decorator to replace the first dimension in input_shape with a dummy batch size if it is None"""
117+
"""Decorator to replace the first dimension in ..._shape arguments with a dummy batch size if it is None"""
118118

119119
# The Keras functional API passes input_shape = (None, second_dim, third_dim, ...), which
120120
# causes problems when constructions like self.call(keras.ops.zeros(input_shape)) are used
@@ -126,5 +126,8 @@ def callback(input_shape: Shape) -> Shape:
126126
return tuple(input_shape)
127127
return input_shape
128128

129-
fn = argument_callback("input_shape", callback)(fn)
129+
args = inspect.getfullargspec(fn).args
130+
for arg in args:
131+
if arg.endswith("_shape"):
132+
fn = argument_callback(arg, callback)(fn)
130133
return fn

examples/From_ABC_to_BayesFlow.ipynb

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@
3838
"outputs": [],
3939
"source": [
4040
"import numpy as np\n",
41-
"import matplotlib.pyplot as plt"
41+
"import matplotlib.pyplot as plt\n",
42+
"import tempfile\n",
43+
"from pathlib import Path\n",
44+
"import platform"
4245
]
4346
},
4447
{
@@ -322,7 +325,9 @@
322325
")\n",
323326
"\n",
324327
"# generate a temporary SQLite DB\n",
325-
"abc_id = abc.new(\"sqlite:////tmp/mjp.db\", observations)"
328+
"prefix = \"sqlite:///\" if platform.system() == \"Windows\" else \"sqlite:////\"\n",
329+
"db_path = (Path(tempfile.gettempdir()).absolute() / \"mjp.db\").as_uri().replace(\"file:///\", prefix)\n",
330+
"abc_id = abc.new(db_path, observations)"
326331
]
327332
},
328333
{

examples/SIR_Posterior_Estimation.ipynb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
"source": [
2020
"import os\n",
2121
"# Set to your favorite backend\n",
22-
"os.environ[\"KERAS_BACKEND\"] = \"jax\""
22+
"if \"KERAS_BACKEND\" not in os.environ:\n",
23+
" # set this to \"torch\", \"tensorflow\", or \"jax\"\n",
24+
" os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
25+
"else:\n",
26+
" print(f\"Using '{os.environ['KERAS_BACKEND']}' backend\")"
2327
]
2428
},
2529
{

examples/Two_Moons_Starter.ipynb

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
"source": [
2525
"import os\n",
2626
"# Set to your favorite backend\n",
27-
"os.environ[\"KERAS_BACKEND\"] = \"jax\""
27+
"if \"KERAS_BACKEND\" not in os.environ:\n",
28+
" # set this to \"torch\", \"tensorflow\", or \"jax\"\n",
29+
" os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
30+
"else:\n",
31+
" print(f\"Using '{os.environ['KERAS_BACKEND']}' backend\")"
2832
]
2933
},
3034
{

0 commit comments

Comments
 (0)