@@ -127,45 +127,55 @@ class Initializable(RNGMixin, Brick):
127127 ``True``.
128128 use_bias : :obj:`bool`, optional
129129 Whether to use a bias. Defaults to `True`. Required by
130- :meth:`~.Brick.initialize`. Only supported by bricks for which
131- :attr:`has_biases` is ``True``.
130+ :meth:`~.Brick.initialize`.
132131 rng : :class:`numpy.random.RandomState`
133132
134- Attributes
135- ----------
136- has_biases : bool
137- ``False`` if the brick does not support biases, and only has
138- :attr:`weights_init`. For an example of this, see
139- :class:`.Bidirectional`. If this is ``False``, the brick does not
140- support the arguments ``biases_init`` or ``use_bias``.
141-
142133 """
143- has_biases = True
144134
145135 @lazy ()
146- def __init__ (self , weights_init = None , biases_init = None , use_bias = None ,
136+ def __init__ (self , initialization_schemes = None , use_bias = True ,
147137 seed = None , ** kwargs ):
148- super (Initializable , self ).__init__ (** kwargs )
149- self .weights_init = weights_init
150- if self .has_biases :
151- self .biases_init = biases_init
152- elif biases_init is not None or not use_bias :
153- raise ValueError ("This brick does not support biases config" )
154- if use_bias is not None :
155- self .use_bias = use_bias
138+ self .use_bias = use_bias
156139 self .seed = seed
140+ self .initialization_schemes = initialization_schemes
141+ self .parameter_roles = set ([])
142+ if self .initialization_schemes is None :
143+ self .initialization_schemes = {}
144+
145+ kwargs_ = {}
146+ for key in kwargs :
147+ if key [- 5 :] == "_init" :
148+ if key in self .initialization_schemes :
149+ raise ValueError ("All initializations are accepted either"
150+ "through initialization_schemes or "
151+ "correspodong attribute but not both" )
152+ else :
153+ self .initialization_schemes [key ] = kwargs [key ]
154+ else :
155+ kwargs_ [key ] = kwargs [key ]
156+
157+ super (Initializable , self ).__init__ (** kwargs_ )
158+ self ._collect_roles ()
157159
158160 def _push_initialization_config (self ):
161+ for child in self .children :
162+ if (isinstance (child , Initializable ) and
163+ hasattr (child , 'initialization_schemes' )):
164+ for role in child .initialization_schemes :
165+ if role not in self .parameter_roles :
166+ raise ValueError ("The parameter role: " +
167+ "{} is not defined in" .format (role ) +
168+ "in the class parameter_roles" )
169+
159170 for child in self .children :
160171 if isinstance (child , Initializable ):
161172 child .rng = self .rng
162- if self .weights_init :
163- child .weights_init = self .weights_init
164- if hasattr (self , 'biases_init' ) and self .biases_init :
165- for child in self .children :
166- if (isinstance (child , Initializable ) and
167- hasattr (child , 'biases_init' )):
168- child .biases_init = self .biases_init
173+ child .initialization_schemes = self .initialization_schemes
174+
175+ def _collect_roles (self ):
176+ for child in self .children :
177+ if isinstance (child , Initializable ):
178+ self .parameter_roles .update (child .parameter_roles )
169179
170180
171181class LinearLike (Initializable ):
@@ -196,9 +206,11 @@ def b(self):
196206 def _initialize (self ):
197207 # Use self.parameters[] references in case W and b are overridden
198208 # to return non-shared-variables.
199- if getattr (self , 'use_bias' , True ):
200- self .biases_init .initialize (self .parameters [1 ], self .rng )
201- self .weights_init .initialize (self .parameters [0 ], self .rng )
209+ if self .use_bias :
210+ self .initialization_schemes ['biases_init' ].initialize (
211+ self .parameters [1 ], self .rng )
212+ self .initialization_schemes ['weights_init' ].initialize (
213+ self .parameters [0 ], self .rng )
202214
203215
204216class Random (Brick ):
0 commit comments