Skip to content

Commit 241105b

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into feat-pairs-plot
2 parents 233c206 + d582111 commit 241105b

File tree

8 files changed

+95
-3
lines changed

8 files changed

+95
-3
lines changed

bayesflow/networks/residual/residual.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
@serializable("bayesflow.networks")
1313
class Residual(Sequential):
1414
def __init__(self, *layers: keras.Layer, **kwargs):
15-
if len(layers) == 1 and isinstance(layers[0], Sequence):
15+
if len(layers) == 0 and "layers" in kwargs:
16+
# extract layers from kwargs, in case they were passed as a keyword argument
17+
layers = kwargs.pop("layers")
18+
elif len(layers) > 0 and "layers" in kwargs:
19+
raise ValueError("Layers passed both as positional argument and as keyword argument")
20+
elif len(layers) == 1 and isinstance(layers[0], Sequence):
1621
layers = layers[0]
1722
super().__init__(list(layers), **sequential_kwargs(kwargs))
1823
self.projector = keras.layers.Dense(units=None, name="projector")
@@ -43,6 +48,8 @@ def build(self, input_shape=None):
4348
# this is a work-around for https://github.com/keras-team/keras/issues/21158
4449
output_shape = input_shape
4550
for layer in self._layers:
51+
if layer.built:
52+
continue
4653
layer.build(output_shape)
4754
output_shape = layer.compute_output_shape(output_shape)
4855

bayesflow/networks/sequential/sequential.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@ class Sequential(keras.Layer):
3131
"""
3232

3333
def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
34-
super().__init__(**layer_kwargs(kwargs))
35-
if len(layers) == 1 and isinstance(layers[0], Sequence):
34+
if len(layers) == 0 and "layers" in kwargs:
35+
# extract layers from kwargs, in case they were passed as a keyword argument
36+
layers = kwargs.pop("layers")
37+
elif len(layers) > 0 and "layers" in kwargs:
38+
raise ValueError("Layers passed both as positional argument and as keyword argument")
39+
elif len(layers) == 1 and isinstance(layers[0], Sequence):
3640
layers = layers[0]
41+
super().__init__(**layer_kwargs(kwargs))
3742

3843
self._layers = layers
3944

@@ -44,6 +49,8 @@ def build(self, input_shape):
4449
return
4550

4651
for layer in self._layers:
52+
if layer.built:
53+
continue
4754
layer.build(input_shape)
4855
input_shape = layer.compute_output_shape(input_shape)
4956

tests/test_networks/test_residual/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from bayesflow.networks.residual import Residual
4+
5+
6+
@pytest.fixture()
7+
def residual():
8+
import keras
9+
10+
return Residual(keras.layers.Flatten(), keras.layers.Dense(2))
11+
12+
13+
@pytest.fixture()
14+
def build_shapes():
15+
return {"input_shape": (32, 2)}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import keras
2+
3+
from bayesflow.utils.serialization import deserialize, serialize
4+
5+
from ...utils import assert_layers_equal
6+
7+
8+
def test_serialize_deserialize(residual, build_shapes):
9+
residual.build(**build_shapes)
10+
11+
serialized = serialize(residual)
12+
deserialized = deserialize(serialized)
13+
reserialized = serialize(deserialized)
14+
15+
assert reserialized == serialized
16+
17+
18+
def test_save_and_load(tmp_path, residual, build_shapes):
19+
residual.build(**build_shapes)
20+
21+
keras.saving.save_model(residual, tmp_path / "model.keras")
22+
loaded = keras.saving.load_model(tmp_path / "model.keras")
23+
24+
assert_layers_equal(residual, loaded)

tests/test_networks/test_sequential/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from bayesflow.networks import Sequential
4+
5+
6+
@pytest.fixture()
7+
def sequential():
8+
import keras
9+
10+
return Sequential(keras.layers.Flatten(), keras.layers.Dense(2))
11+
12+
13+
@pytest.fixture()
14+
def build_shapes():
15+
return {"input_shape": (32, 2)}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import keras
2+
3+
from bayesflow.utils.serialization import deserialize, serialize
4+
5+
from ...utils import assert_layers_equal
6+
7+
8+
def test_serialize_deserialize(sequential, build_shapes):
9+
sequential.build(**build_shapes)
10+
11+
serialized = serialize(sequential)
12+
deserialized = deserialize(serialized)
13+
reserialized = serialize(deserialized)
14+
15+
assert reserialized == serialized
16+
17+
18+
def test_save_and_load(tmp_path, sequential, build_shapes):
19+
sequential.build(**build_shapes)
20+
21+
keras.saving.save_model(sequential, tmp_path / "model.keras")
22+
loaded = keras.saving.load_model(tmp_path / "model.keras")
23+
24+
assert_layers_equal(sequential, loaded)

0 commit comments

Comments
 (0)