Skip to content

Commit 0b09487

Browse files
committed
fix: layers were not deserialized for Sequential and Residual
As layers were passed with the `*layers` syntax, they could not be passed as keyword arguments. In `from_config`, however, this was attempted, leading to the layers to be ignored during reserialization. This commit fixes this by taking the layers from `kwargs` if they are passed as a keyword argument.
1 parent 8567049 commit 0b09487

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

bayesflow/networks/residual/residual.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
@serializable("bayesflow.networks")
1313
class Residual(Sequential):
1414
def __init__(self, *layers: keras.Layer, **kwargs):
15-
if len(layers) == 1 and isinstance(layers[0], Sequence):
15+
if len(layers) == 0 and "layers" in kwargs:
16+
# extract layers from kwargs, in case they were passed as a keyword argument
17+
layers = kwargs.pop("layers")
18+
elif len(layers) > 0 and "layers" in kwargs:
19+
raise ValueError("Layers passed both as positional argument and as keyword argument")
20+
elif len(layers) == 1 and isinstance(layers[0], Sequence):
1621
layers = layers[0]
1722
super().__init__(list(layers), **sequential_kwargs(kwargs))
1823
self.projector = keras.layers.Dense(units=None, name="projector")
@@ -43,6 +48,8 @@ def build(self, input_shape=None):
4348
# this is a work-around for https://github.com/keras-team/keras/issues/21158
4449
output_shape = input_shape
4550
for layer in self._layers:
51+
if layer.built:
52+
continue
4653
layer.build(output_shape)
4754
output_shape = layer.compute_output_shape(output_shape)
4855

bayesflow/networks/sequential/sequential.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,14 @@ class Sequential(keras.Layer):
3131
"""
3232

3333
def __init__(self, *layers: keras.Layer | Sequence[keras.Layer], **kwargs):
34-
super().__init__(**layer_kwargs(kwargs))
35-
if len(layers) == 1 and isinstance(layers[0], Sequence):
34+
if len(layers) == 0 and "layers" in kwargs:
35+
# extract layers from kwargs, in case they were passed as a keyword argument
36+
layers = kwargs.pop("layers")
37+
elif len(layers) > 0 and "layers" in kwargs:
38+
raise ValueError("Layers passed both as positional argument and as keyword argument")
39+
elif len(layers) == 1 and isinstance(layers[0], Sequence):
3640
layers = layers[0]
41+
super().__init__(**layer_kwargs(kwargs))
3742

3843
self._layers = layers
3944

@@ -44,6 +49,8 @@ def build(self, input_shape):
4449
return
4550

4651
for layer in self._layers:
52+
if layer.built:
53+
continue
4754
layer.build(input_shape)
4855
input_shape = layer.compute_output_shape(input_shape)
4956

0 commit comments

Comments
 (0)