diff --git a/bayesflow/networks/residual/residual.py b/bayesflow/networks/residual/residual.py
index 9d33e4057..457602eb0 100644
--- a/bayesflow/networks/residual/residual.py
+++ b/bayesflow/networks/residual/residual.py
@@ -12,7 +12,12 @@
 @serializable("bayesflow.networks")
 class Residual(Sequential):
     def __init__(self, *layers: keras.Layer, **kwargs):
-        if len(layers) == 1 and isinstance(layers[0], Sequence):
+        if len(layers) == 0 and "layers" in kwargs:
+            # extract layers from kwargs, in case they were passed as a keyword argument
+            layers = kwargs.pop("layers")
+        elif len(layers) > 0 and "layers" in kwargs:
+            raise ValueError("Layers passed both as positional argument and as keyword argument")
+        elif len(layers) == 1 and isinstance(layers[0], Sequence):
             layers = layers[0]
         super().__init__(list(layers), **sequential_kwargs(kwargs))
         self.projector = keras.layers.Dense(units=None, name="projector")
@@ -43,6 +48,8 @@ def build(self, input_shape=None):
         # this is a work-around for https://github.com/keras-team/keras/issues/21158
         output_shape = input_shape
         for layer in self._layers:
-            layer.build(output_shape)
+            # skip re-building already-built layers, but still propagate the shape
+            if not layer.built:
+                layer.build(output_shape)
             output_shape = layer.compute_output_shape(output_shape)
 
diff --git a/bayesflow/networks/sequential/sequential.py b/bayesflow/networks/sequential/sequential.py
index d52a166b2..ff931772a 100644
--- a/bayesflow/networks/sequential/sequential.py
+++ b/bayesflow/networks/sequential/sequential.py
@@ -31,9 +31,14 @@ class Sequential(keras.Layer):
     """
 
     def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
-        super().__init__(**layer_kwargs(kwargs))
-        if len(layers) == 1 and isinstance(layers[0], Sequence):
+        if len(layers) == 0 and "layers" in kwargs:
+            # extract layers from kwargs, in case they were passed as a keyword argument
+            layers = kwargs.pop("layers")
+        elif len(layers) > 0 and "layers" in kwargs:
+            raise ValueError("Layers passed both as positional argument and as keyword argument")
+        elif len(layers) == 1 and isinstance(layers[0], Sequence):
             layers = layers[0]
+        super().__init__(**layer_kwargs(kwargs))
 
         self._layers = layers
 
@@ -44,6 +49,8 @@ def build(self, input_shape):
             return
 
         for layer in self._layers:
-            layer.build(input_shape)
+            # skip re-building already-built layers, but still propagate the shape
+            if not layer.built:
+                layer.build(input_shape)
             input_shape = layer.compute_output_shape(input_shape)
 
diff --git a/tests/test_networks/test_residual/__init__.py b/tests/test_networks/test_residual/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_networks/test_residual/conftest.py b/tests/test_networks/test_residual/conftest.py
new file mode 100644
index 000000000..e7ba511bb
--- /dev/null
+++ b/tests/test_networks/test_residual/conftest.py
@@ -0,0 +1,15 @@
+import pytest
+
+from bayesflow.networks.residual import Residual
+
+
+@pytest.fixture()
+def residual():
+    import keras
+
+    return Residual(keras.layers.Flatten(), keras.layers.Dense(2))
+
+
+@pytest.fixture()
+def build_shapes():
+    return {"input_shape": (32, 2)}
diff --git a/tests/test_networks/test_residual/test_residual.py b/tests/test_networks/test_residual/test_residual.py
new file mode 100644
index 000000000..97a0229f5
--- /dev/null
+++ b/tests/test_networks/test_residual/test_residual.py
@@ -0,0 +1,36 @@
+import keras
+
+from bayesflow.networks.residual import Residual
+from bayesflow.utils.serialization import deserialize, serialize
+
+from ...utils import assert_layers_equal
+
+
+def test_serialize_deserialize(residual, build_shapes):
+    residual.build(**build_shapes)
+
+    serialized = serialize(residual)
+    deserialized = deserialize(serialized)
+    reserialized = serialize(deserialized)
+
+    assert reserialized == serialized
+
+
+def test_save_and_load(tmp_path, residual, build_shapes):
+    residual.build(**build_shapes)
+
+    keras.saving.save_model(residual, tmp_path / "model.keras")
+    loaded = keras.saving.load_model(tmp_path / "model.keras")
+
+    assert_layers_equal(residual, loaded)
+
+
+def test_build_propagates_shape_through_built_layers():
+    # an already-built layer must still pass its output shape on to later layers
+    first = keras.layers.Dense(4)
+    first.build((32, 3))
+    second = keras.layers.Dense(2)
+
+    Residual(first, second).build((32, 3))
+
+    assert tuple(second.kernel.shape) == (4, 2)
diff --git a/tests/test_networks/test_sequential/__init__.py b/tests/test_networks/test_sequential/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_networks/test_sequential/conftest.py b/tests/test_networks/test_sequential/conftest.py
new file mode 100644
index 000000000..7a6ab3fef
--- /dev/null
+++ b/tests/test_networks/test_sequential/conftest.py
@@ -0,0 +1,15 @@
+import pytest
+
+from bayesflow.networks import Sequential
+
+
+@pytest.fixture()
+def sequential():
+    import keras
+
+    return Sequential(keras.layers.Flatten(), keras.layers.Dense(2))
+
+
+@pytest.fixture()
+def build_shapes():
+    return {"input_shape": (32, 2)}
diff --git a/tests/test_networks/test_sequential/test_sequential.py b/tests/test_networks/test_sequential/test_sequential.py
new file mode 100644
index 000000000..82f3f5d84
--- /dev/null
+++ b/tests/test_networks/test_sequential/test_sequential.py
@@ -0,0 +1,36 @@
+import keras
+
+from bayesflow.networks import Sequential
+from bayesflow.utils.serialization import deserialize, serialize
+
+from ...utils import assert_layers_equal
+
+
+def test_serialize_deserialize(sequential, build_shapes):
+    sequential.build(**build_shapes)
+
+    serialized = serialize(sequential)
+    deserialized = deserialize(serialized)
+    reserialized = serialize(deserialized)
+
+    assert reserialized == serialized
+
+
+def test_save_and_load(tmp_path, sequential, build_shapes):
+    sequential.build(**build_shapes)
+
+    keras.saving.save_model(sequential, tmp_path / "model.keras")
+    loaded = keras.saving.load_model(tmp_path / "model.keras")
+
+    assert_layers_equal(sequential, loaded)
+
+
+def test_build_propagates_shape_through_built_layers():
+    # an already-built layer must still pass its output shape on to later layers
+    first = keras.layers.Dense(4)
+    first.build((32, 3))
+    second = keras.layers.Dense(2)
+
+    Sequential(first, second).build((32, 3))
+
+    assert tuple(second.kernel.shape) == (4, 2)