Skip to content

Commit 21ce4cf

Browse files
committed
managing initializations via a role scheme dictionary
1 parent 9f8189e commit 21ce4cf

File tree

2 files changed

+43
-30
lines changed

2 files changed

+43
-30
lines changed

blocks/bricks/interfaces.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -127,45 +127,55 @@ class Initializable(RNGMixin, Brick):
127127
``True``.
128128
use_bias : :obj:`bool`, optional
129129
Whether to use a bias. Defaults to `True`. Required by
130-
:meth:`~.Brick.initialize`. Only supported by bricks for which
131-
:attr:`has_biases` is ``True``.
130+
:meth:`~.Brick.initialize`.
132131
rng : :class:`numpy.random.RandomState`
133132
134-
Attributes
135-
----------
136-
has_biases : bool
137-
``False`` if the brick does not support biases, and only has
138-
:attr:`weights_init`. For an example of this, see
139-
:class:`.Bidirectional`. If this is ``False``, the brick does not
140-
support the arguments ``biases_init`` or ``use_bias``.
141-
142133
"""
143-
has_biases = True
144134

145135
@lazy()
146-
def __init__(self, weights_init=None, biases_init=None, use_bias=None,
136+
def __init__(self, initialization_schemes=None, use_bias=True,
147137
seed=None, **kwargs):
148-
super(Initializable, self).__init__(**kwargs)
149-
self.weights_init = weights_init
150-
if self.has_biases:
151-
self.biases_init = biases_init
152-
elif biases_init is not None or not use_bias:
153-
raise ValueError("This brick does not support biases config")
154-
if use_bias is not None:
155-
self.use_bias = use_bias
138+
self.use_bias = use_bias
156139
self.seed = seed
140+
self.initialization_schemes = initialization_schemes
141+
self.parameter_roles = set([])
142+
if self.initialization_schemes is None:
143+
self.initialization_schemes = {}
144+
145+
kwargs_ = {}
146+
for key in kwargs:
147+
if key[-5:] == "_init":
148+
if key in self.initialization_schemes:
149+
raise ValueError("All initializations are accepted either"
150+
"through initialization_schemes or "
151+
"correspodong attribute but not both")
152+
else:
153+
self.initialization_schemes[key] = kwargs[key]
154+
else:
155+
kwargs_[key] = kwargs[key]
156+
157+
super(Initializable, self).__init__(**kwargs_)
158+
self._collect_roles()
157159

158160
def _push_initialization_config(self):
161+
for child in self.children:
162+
if (isinstance(child, Initializable) and
163+
hasattr(child, 'initialization_schemes')):
164+
for role in child.initialization_schemes:
165+
if role not in self.parameter_roles:
166+
raise ValueError("The parameter role: " +
167+
"{} is not defined in".format(role) +
168+
"in the class parameter_roles")
169+
159170
for child in self.children:
160171
if isinstance(child, Initializable):
161172
child.rng = self.rng
162-
if self.weights_init:
163-
child.weights_init = self.weights_init
164-
if hasattr(self, 'biases_init') and self.biases_init:
165-
for child in self.children:
166-
if (isinstance(child, Initializable) and
167-
hasattr(child, 'biases_init')):
168-
child.biases_init = self.biases_init
173+
child.initialization_schemes = self.initialization_schemes
174+
175+
def _collect_roles(self):
176+
for child in self.children:
177+
if isinstance(child, Initializable):
178+
self.parameter_roles.update(child.parameter_roles)
169179

170180

171181
class LinearLike(Initializable):
@@ -196,9 +206,11 @@ def b(self):
196206
def _initialize(self):
197207
# Use self.parameters[] references in case W and b are overridden
198208
# to return non-shared-variables.
199-
if getattr(self, 'use_bias', True):
200-
self.biases_init.initialize(self.parameters[1], self.rng)
201-
self.weights_init.initialize(self.parameters[0], self.rng)
209+
if self.use_bias:
210+
self.initialization_schemes['biases_init'].initialize(
211+
self.parameters[1], self.rng)
212+
self.initialization_schemes['weights_init'].initialize(
213+
self.parameters[0], self.rng)
202214

203215

204216
class Random(Brick):

blocks/bricks/simple.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, input_dim, output_dim, **kwargs):
4343
super(Linear, self).__init__(**kwargs)
4444
self.input_dim = input_dim
4545
self.output_dim = output_dim
46+
self.parameter_roles = set(['weights_init', 'biases_init'])
4647

4748
def _allocate(self):
4849
W = shared_floatx_nans((self.input_dim, self.output_dim), name='W')

0 commit comments

Comments
 (0)