@@ -127,44 +127,55 @@ class Initializable(RNGMixin, Brick):
127127 ``True``.
128128 use_bias : :obj:`bool`, optional
129129 Whether to use a bias. Defaults to `True`. Required by
130- :meth:`~.Brick.initialize`. Only supported by bricks for which
131- :attr:`has_biases` is ``True``.
130+ :meth:`~.Brick.initialize`.
132131 rng : :class:`numpy.random.RandomState`
133132
134- Attributes
135- ----------
136- has_biases : bool
137- ``False`` if the brick does not support biases, and only has
138- :attr:`weights_init`. For an example of this, see
139- :class:`.Bidirectional`. If this is ``False``, the brick does not
140- support the arguments ``biases_init`` or ``use_bias``.
141-
142133 """
143- has_biases = True
144134
145135 @lazy ()
146- def __init__ (self , weights_init = None , biases_init = None , use_bias = True ,
136+ def __init__ (self , initialization_schemes = None , use_bias = True ,
147137 seed = None , ** kwargs ):
148- super (Initializable , self ).__init__ (** kwargs )
149- self .weights_init = weights_init
150- if self .has_biases :
151- self .biases_init = biases_init
152- elif biases_init is not None or not use_bias :
153- raise ValueError ("This brick does not support biases config" )
154138 self .use_bias = use_bias
155139 self .seed = seed
140+ self .initialization_schemes = initialization_schemes
141+ self .parameter_roles = set ([])
142+ if self .initialization_schemes is None :
143+ self .initialization_schemes = {}
144+
145+ kwargs_ = {}
146+ for key in kwargs :
147+ if key [- 5 :] == "_init" :
148+ if key in self .initialization_schemes :
149+ raise ValueError ("All initializations are accepted either"
150+ "through initialization_schemes or "
151+ "correspodong attribute but not both" )
152+ else :
153+ self .initialization_schemes [key ] = kwargs [key ]
154+ else :
155+ kwargs_ [key ] = kwargs [key ]
156+
157+ super (Initializable , self ).__init__ (** kwargs_ )
158+ self ._collect_roles ()
156159
157160 def _push_initialization_config (self ):
161+ for child in self .children :
162+ if (isinstance (child , Initializable ) and
163+ hasattr (child , 'initialization_schemes' )):
164+ for role in child .initialization_schemes :
165+ if role not in self .parameter_roles :
166+ raise ValueError ("The parameter role: " +
167+ "{} is not defined in" .format (role ) +
168+ "in the class parameter_roles" )
169+
158170 for child in self .children :
159171 if isinstance (child , Initializable ):
160172 child .rng = self .rng
161- if self .weights_init :
162- child .weights_init = self .weights_init
163- if hasattr (self , 'biases_init' ) and self .biases_init :
164- for child in self .children :
165- if (isinstance (child , Initializable ) and
166- hasattr (child , 'biases_init' )):
167- child .biases_init = self .biases_init
173+ child .initialization_schemes = self .initialization_schemes
174+
175+ def _collect_roles (self ):
176+ for child in self .children :
177+ if isinstance (child , Initializable ):
178+ self .parameter_roles .update (child .parameter_roles )
168179
169180
170181class LinearLike (Initializable ):
@@ -196,8 +207,10 @@ def _initialize(self):
196207 # Use self.parameters[] references in case W and b are overridden
197208 # to return non-shared-variables.
198209 if self .use_bias :
199- self .biases_init .initialize (self .parameters [1 ], self .rng )
200- self .weights_init .initialize (self .parameters [0 ], self .rng )
210+ self .initialization_schemes ['biases_init' ].initialize (
211+ self .parameters [1 ], self .rng )
212+ self .initialization_schemes ['weights_init' ].initialize (
213+ self .parameters [0 ], self .rng )
201214
202215
203216class Random (Brick ):
0 commit comments