33
44import keras
55
6- from bayesflow .utils import sequential_kwargs
6+ from bayesflow .utils import layer_kwargs
77from bayesflow .utils .serialization import deserialize , serializable , serialize
88
9+ from ..sequential import Sequential
910from ..residual import Residual
1011
1112
1213@serializable ("bayesflow.networks" )
13- class MLP (keras . Sequential ):
14+ class MLP (Sequential ):
1415 """
1516 Implements a simple configurable MLP with optional residual connections and dropout.
1617
@@ -67,40 +68,44 @@ def __init__(
6768 self .norm = norm
6869 self .spectral_normalization = spectral_normalization
6970
70- layers = []
71+ blocks = []
7172
7273 for width in widths :
73- layer = self ._make_layer (
74+ block = self ._make_block (
7475 width , activation , kernel_initializer , residual , dropout , norm , spectral_normalization
7576 )
76- layers .append (layer )
77+ blocks .append (block )
7778
78- super ().__init__ (layers , ** sequential_kwargs ( kwargs ) )
79+ super ().__init__ (* blocks , ** kwargs )
7980
8081 def build (self , input_shape = None ):
8182 if self .built :
8283 # building when the network is already built can cause issues with serialization
8384 # see https://github.com/keras-team/keras/issues/21147
8485 return
8586
86- # we only care about the last dimension, and using ... signifies to keras.Sequential
87- # that any number of batch dimensions is valid (which is what we want for all sublayers)
88- # we also have to avoid calling super().build() because this causes
89- # shape errors when building on non-sets but doing inference on sets
90- # this is a work-around for https://github.com/keras-team/keras/issues/21158
91- input_shape = (..., input_shape [- 1 ])
92-
9387 for layer in self ._layers :
9488 layer .build (input_shape )
9589 input_shape = layer .compute_output_shape (input_shape )
9690
91+ def call (self , x , training = None , mask = None ):
92+ for layer in self ._layers :
93+ kwargs = {}
94+ if layer ._call_has_mask_arg :
95+ kwargs ["mask" ] = mask
96+ if layer ._call_has_training_arg and training is not None :
97+ kwargs ["training" ] = training
98+
99+ x = layer (x , ** kwargs )
100+ return x
101+
97102 @classmethod
98103 def from_config (cls , config , custom_objects = None ):
99104 return cls (** deserialize (config , custom_objects = custom_objects ))
100105
101106 def get_config (self ):
102107 base_config = super ().get_config ()
103- base_config = sequential_kwargs (base_config )
108+ base_config = layer_kwargs (base_config )
104109
105110 config = {
106111 "widths" : self .widths ,
@@ -115,7 +120,7 @@ def get_config(self):
115120 return base_config | serialize (config )
116121
117122 @staticmethod
118- def _make_layer (width , activation , kernel_initializer , residual , dropout , norm , spectral_normalization ):
123+ def _make_block (width , activation , kernel_initializer , residual , dropout , norm , spectral_normalization ):
119124 layers = []
120125
121126 dense = keras .layers .Dense (width , kernel_initializer = kernel_initializer )
@@ -148,4 +153,4 @@ def _make_layer(width, activation, kernel_initializer, residual, dropout, norm,
148153 if residual :
149154 return Residual (* layers )
150155
151- return keras . Sequential (layers )
156+ return Sequential (layers )
0 commit comments