File tree Expand file tree Collapse file tree 8 files changed +95
-3
lines changed Expand file tree Collapse file tree 8 files changed +95
-3
lines changed Original file line number Diff line number Diff line change 1212@serializable ("bayesflow.networks" )
1313class Residual (Sequential ):
1414 def __init__ (self , * layers : keras .Layer , ** kwargs ):
15- if len (layers ) == 1 and isinstance (layers [0 ], Sequence ):
15+ if len (layers ) == 0 and "layers" in kwargs :
16+ # extract layers from kwargs, in case they were passed as a keyword argument
17+ layers = kwargs .pop ("layers" )
18+ elif len (layers ) > 0 and "layers" in kwargs :
19+ raise ValueError ("Layers passed both as positional argument and as keyword argument" )
20+ elif len (layers ) == 1 and isinstance (layers [0 ], Sequence ):
1621 layers = layers [0 ]
1722 super ().__init__ (list (layers ), ** sequential_kwargs (kwargs ))
1823 self .projector = keras .layers .Dense (units = None , name = "projector" )
@@ -43,6 +48,8 @@ def build(self, input_shape=None):
4348 # this is a work-around for https://github.com/keras-team/keras/issues/21158
4449 output_shape = input_shape
4550 for layer in self ._layers :
51+ if layer .built :
52+ continue
4653 layer .build (output_shape )
4754 output_shape = layer .compute_output_shape (output_shape )
4855
Original file line number Diff line number Diff line change @@ -31,9 +31,14 @@ class Sequential(keras.Layer):
3131 """
3232
3333 def __init__ (self , * layers : keras .Layer | Sequence [keras .Layer ], ** kwargs ):
34- super ().__init__ (** layer_kwargs (kwargs ))
35- if len (layers ) == 1 and isinstance (layers [0 ], Sequence ):
34+ if len (layers ) == 0 and "layers" in kwargs :
35+ # extract layers from kwargs, in case they were passed as a keyword argument
36+ layers = kwargs .pop ("layers" )
37+ elif len (layers ) > 0 and "layers" in kwargs :
38+ raise ValueError ("Layers passed both as positional argument and as keyword argument" )
39+ elif len (layers ) == 1 and isinstance (layers [0 ], Sequence ):
3640 layers = layers [0 ]
41+ super ().__init__ (** layer_kwargs (kwargs ))
3742
3843 self ._layers = layers
3944
@@ -44,6 +49,8 @@ def build(self, input_shape):
4449 return
4550
4651 for layer in self ._layers :
52+ if layer .built :
53+ continue
4754 layer .build (input_shape )
4855 input_shape = layer .compute_output_shape (input_shape )
4956
Original file line number Diff line number Diff line change 1+ import pytest
2+
3+ from bayesflow .networks .residual import Residual
4+
5+
6+ @pytest .fixture ()
7+ def residual ():
8+ import keras
9+
10+ return Residual (keras .layers .Flatten (), keras .layers .Dense (2 ))
11+
12+
13+ @pytest .fixture ()
14+ def build_shapes ():
15+ return {"input_shape" : (32 , 2 )}
Original file line number Diff line number Diff line change 1+ import keras
2+
3+ from bayesflow .utils .serialization import deserialize , serialize
4+
5+ from ...utils import assert_layers_equal
6+
7+
8+ def test_serialize_deserialize (residual , build_shapes ):
9+ residual .build (** build_shapes )
10+
11+ serialized = serialize (residual )
12+ deserialized = deserialize (serialized )
13+ reserialized = serialize (deserialized )
14+
15+ assert reserialized == serialized
16+
17+
18+ def test_save_and_load (tmp_path , residual , build_shapes ):
19+ residual .build (** build_shapes )
20+
21+ keras .saving .save_model (residual , tmp_path / "model.keras" )
22+ loaded = keras .saving .load_model (tmp_path / "model.keras" )
23+
24+ assert_layers_equal (residual , loaded )
Original file line number Diff line number Diff line change 1+ import pytest
2+
3+ from bayesflow .networks import Sequential
4+
5+
6+ @pytest .fixture ()
7+ def sequential ():
8+ import keras
9+
10+ return Sequential (keras .layers .Flatten (), keras .layers .Dense (2 ))
11+
12+
13+ @pytest .fixture ()
14+ def build_shapes ():
15+ return {"input_shape" : (32 , 2 )}
Original file line number Diff line number Diff line change 1+ import keras
2+
3+ from bayesflow .utils .serialization import deserialize , serialize
4+
5+ from ...utils import assert_layers_equal
6+
7+
8+ def test_serialize_deserialize (sequential , build_shapes ):
9+ sequential .build (** build_shapes )
10+
11+ serialized = serialize (sequential )
12+ deserialized = deserialize (serialized )
13+ reserialized = serialize (deserialized )
14+
15+ assert reserialized == serialized
16+
17+
18+ def test_save_and_load (tmp_path , sequential , build_shapes ):
19+ sequential .build (** build_shapes )
20+
21+ keras .saving .save_model (sequential , tmp_path / "model.keras" )
22+ loaded = keras .saving .load_model (tmp_path / "model.keras" )
23+
24+ assert_layers_equal (sequential , loaded )
You can’t perform that action at this time.
0 commit comments