
Commit dd1ca0c

add custom sequential to fix #491

1 parent cd2c212 commit dd1ca0c

8 files changed: +112 -23 lines

bayesflow/diagnostics/metrics/classifier_two_sample_test.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 import keras
 
 from bayesflow.utils.exceptions import ShapeError
-from bayesflow.networks import MLP
+from bayesflow.networks import MLP, Sequential
 
 
 def classifier_two_sample_test(
@@ -96,7 +96,7 @@ def classifier_two_sample_test(
     labels = labels[shuffle_idx]
 
     # Create and train classifier with optional stopping
-    classifier = keras.Sequential(
+    classifier = Sequential(
         [MLP(widths=mlp_widths, **kwargs.get("mlp_kwargs", {})), keras.layers.Dense(1, activation="sigmoid")]
     )
 
bayesflow/networks/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 from .point_inference_network import PointInferenceNetwork
 from .mlp import MLP
 from .fusion_network import FusionNetwork
+from .sequential import Sequential
 from .summary_network import SummaryNetwork
 from .time_series_network import TimeSeriesNetwork
 from .transformers import SetTransformer, TimeSeriesTransformer, FusionTransformer
bayesflow/networks/mlp/mlp.py

Lines changed: 21 additions & 16 deletions

@@ -3,14 +3,15 @@
 
 import keras
 
-from bayesflow.utils import sequential_kwargs
+from bayesflow.utils import layer_kwargs
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
+from ..sequential import Sequential
 from ..residual import Residual
 
 
 @serializable("bayesflow.networks")
-class MLP(keras.Sequential):
+class MLP(Sequential):
     """
     Implements a simple configurable MLP with optional residual connections and dropout.
 
@@ -67,40 +68,44 @@ def __init__(
         self.norm = norm
         self.spectral_normalization = spectral_normalization
 
-        layers = []
+        blocks = []
 
         for width in widths:
-            layer = self._make_layer(
+            block = self._make_block(
                 width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization
             )
-            layers.append(layer)
+            blocks.append(block)
 
-        super().__init__(layers, **sequential_kwargs(kwargs))
+        super().__init__(*blocks, **kwargs)
 
     def build(self, input_shape=None):
         if self.built:
             # building when the network is already built can cause issues with serialization
             # see https://github.com/keras-team/keras/issues/21147
             return
 
-        # we only care about the last dimension, and using ... signifies to keras.Sequential
-        # that any number of batch dimensions is valid (which is what we want for all sublayers)
-        # we also have to avoid calling super().build() because this causes
-        # shape errors when building on non-sets but doing inference on sets
-        # this is a work-around for https://github.com/keras-team/keras/issues/21158
-        input_shape = (..., input_shape[-1])
-
         for layer in self._layers:
             layer.build(input_shape)
             input_shape = layer.compute_output_shape(input_shape)
 
+    def call(self, x, training=None, mask=None):
+        for layer in self._layers:
+            kwargs = {}
+            if layer._call_has_mask_arg:
+                kwargs["mask"] = mask
+            if layer._call_has_training_arg and training is not None:
+                kwargs["training"] = training
+
+            x = layer(x, **kwargs)
+        return x
+
     @classmethod
     def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = sequential_kwargs(base_config)
+        base_config = layer_kwargs(base_config)
 
         config = {
             "widths": self.widths,
@@ -115,7 +120,7 @@ def get_config(self):
         return base_config | serialize(config)
 
     @staticmethod
-    def _make_layer(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
+    def _make_block(width, activation, kernel_initializer, residual, dropout, norm, spectral_normalization):
         layers = []
 
         dense = keras.layers.Dense(width, kernel_initializer=kernel_initializer)
@@ -148,4 +153,4 @@ def _make_layer(width, activation, kernel_initializer, residual, dropout, norm,
         if residual:
             return Residual(*layers)
 
-        return keras.Sequential(layers)
+        return Sequential(layers)
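The net effect of the diff above: each entry in widths now produces one self-contained block (Dense plus the optional norm, activation, dropout, and spectral-normalization layers), and both the blocks and the MLP itself use the package's own Sequential instead of keras.Sequential. A minimal sketch of how this looks from the outside, assuming this branch of bayesflow is installed; the shapes and option values are illustrative and not taken from the commit:

import keras
from bayesflow.networks import MLP, Sequential

# Illustrative configuration: two hidden widths, dropout and batch norm enabled.
mlp = MLP([64, 64], dropout=0.1, norm="batch", residual=False)

x = keras.random.normal((8, 16))   # batch of 8, feature dimension 16 (arbitrary)
y = mlp(x)
print(y.shape)                     # (8, 64): the last width determines the output size

print(len(mlp.layers))                        # 2 -- one block per entry in widths
print(isinstance(mlp.layers[0], Sequential))  # True -- blocks are the custom Sequential, not keras.Sequential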

bayesflow/networks/residual/residual.py

Lines changed: 3 additions & 1 deletion

@@ -6,9 +6,11 @@
 from bayesflow.utils import sequential_kwargs
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
+from ..sequential import Sequential
+
 
 @serializable("bayesflow.networks")
-class Residual(keras.Sequential):
+class Residual(Sequential):
     def __init__(self, *layers: keras.Layer, **kwargs):
         if len(layers) == 1 and isinstance(layers[0], Sequence):
             layers = layers[0]
bayesflow/networks/sequential/__init__.py (new file)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .sequential import Sequential
bayesflow/networks/sequential/sequential.py (new file)

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+from collections.abc import Sequence
+import keras
+
+from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import deserialize, serializable, serialize
+
+
+@serializable("bayesflow.networks")
+class Sequential(keras.Layer):
+    def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
+        super().__init__(**layer_kwargs(kwargs))
+        if len(layers) == 1 and isinstance(layers[0], Sequence):
+            layers = layers[0]
+
+        self._layers = layers
+
+    def build(self, input_shape):
+        if self.built:
+            # building when the network is already built can cause issues with serialization
+            # see https://github.com/keras-team/keras/issues/21147
+            return
+
+        for layer in self._layers:
+            layer.build(input_shape)
+            input_shape = layer.compute_output_shape(input_shape)
+
+    def call(self, inputs, training=None, mask=None):
+        x = inputs
+        for layer in self._layers:
+            kwargs = self._make_kwargs_for_layer(layer, training, mask)
+            x = layer(x, **kwargs)
+        return x
+
+    def compute_output_shape(self, input_shape):
+        for layer in self._layers:
+            input_shape = layer.compute_output_shape(input_shape)
+
+        return input_shape
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config = layer_kwargs(base_config)
+
+        config = {
+            "layers": [serialize(layer) for layer in self._layers],
+        }
+
+        return base_config | config
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
+
+    @property
+    def layers(self):
+        return self._layers
+
+    @staticmethod
+    def _make_kwargs_for_layer(layer, training, mask):
+        kwargs = {}
+        if layer._call_has_mask_arg:
+            kwargs["mask"] = mask
+        if layer._call_has_training_arg and training is not None:
+            kwargs["training"] = training
+        return kwargs
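For orientation: the new class is a plain keras.Layer that chains sublayers by hand. build propagates shapes through compute_output_shape, call forwards training and mask only to sublayers whose call signature accepts them, and get_config serializes the sublayer list directly. A minimal usage sketch, assuming this commit's bayesflow is installed; the layer sizes and data are illustrative:

import keras
from bayesflow.networks import Sequential

block = Sequential(
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dropout(0.1),                  # its call accepts training=, so the flag is forwarded
    keras.layers.Dense(1, activation="sigmoid"),
)

x = keras.random.normal((4, 8))
y_train = block(x, training=True)    # dropout active
y_eval = block(x, training=False)    # dropout bypassed
print(y_eval.shape)                  # (4, 1)
print(block.layers)                  # the three sublayers, exposed via the new `layers` property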

tests/test_networks/test_mlp/conftest.py

Lines changed: 17 additions & 2 deletions

@@ -3,9 +3,24 @@
 from bayesflow.networks import MLP
 
 
+@pytest.fixture(params=[None, 0.0, 0.1])
+def dropout(request):
+    return request.param
+
+
+@pytest.fixture(params=[None, "batch"])
+def norm(request):
+    return request.param
+
+
+@pytest.fixture(params=[False, True])
+def residual(request):
+    return request.param
+
+
 @pytest.fixture()
-def mlp():
-    return MLP([64, 64])
+def mlp(dropout, norm, residual):
+    return MLP([64, 64], dropout=dropout, norm=norm, residual=residual)
 
 
 @pytest.fixture()
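With the three parametrized fixtures above, every test that requests the mlp fixture now runs against 3 x 2 x 2 = 12 configurations instead of one. The same coverage could be written with explicit pytest.mark.parametrize; a hypothetical sketch (not part of the commit, test name and shapes are mine) to make the expansion concrete:

import keras
import pytest

from bayesflow.networks import MLP


@pytest.mark.parametrize("dropout", [None, 0.0, 0.1])
@pytest.mark.parametrize("norm", [None, "batch"])
@pytest.mark.parametrize("residual", [False, True])
def test_mlp_forward(dropout, norm, residual):
    # 3 x 2 x 2 = 12 cases, mirroring the fixture product above
    mlp = MLP([64, 64], dropout=dropout, norm=norm, residual=residual)
    x = keras.random.normal((4, 64))  # feature dim matches the width so residual blocks line up
    assert mlp(x).shape == (4, 64)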

tests/test_networks/test_mlp/test_mlp.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 
 from bayesflow.utils.serialization import deserialize, serialize
 
-from ...utils import assert_models_equal
+from ...utils import assert_layers_equal
 
 
 def test_serialize_deserialize(mlp, build_shapes):
@@ -21,4 +21,4 @@ def test_save_and_load(tmp_path, mlp, build_shapes):
     keras.saving.save_model(mlp, tmp_path / "model.keras")
     loaded = keras.saving.load_model(tmp_path / "model.keras")
 
-    assert_models_equal(mlp, loaded)
+    assert_layers_equal(mlp, loaded)
