Skip to content

Commit 5087381

Browse files
committed
support weights_init, biases_init
1 parent 2fa5ea4 commit 5087381

File tree

4 files changed

+43
-23
lines changed

4 files changed

+43
-23
lines changed

blocks/bricks/interfaces.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class Initializable(RNGMixin, Brick):
133133
134134
"""
135135

136-
initializable_roles = ['WEIGHT', 'BIAS', 'FILTER', 'INITIAL_STATE']
136+
initializable_roles = [WEIGHT, BIAS, FILTER, INITIAL_STATE]
137137

138138
@lazy()
139139
def __init__(self, initialization_schemes=None, use_bias=True,
@@ -145,39 +145,39 @@ def __init__(self, initialization_schemes=None, use_bias=True,
145145
if self.initialization_schemes is None:
146146
self.initialization_schemes = {}
147147

148-
149-
initialization_to_role = {"weights_init": 'WEIGHT', 'biases_init': 'BIAS',
150-
'initial_state_init': 'INITIAL_STATE'}
148+
initialization_to_role = {"weights_init": WEIGHT, 'biases_init': BIAS,
149+
'initial_state_init': INITIAL_STATE}
151150
for key in list(kwargs.keys()):
152151
if key[-5:] == "_init":
153-
if initialization_to_role[key] in self.initialization_schemes.keys():
152+
if initialization_to_role[key] in \
153+
self.initialization_schemes.keys():
154154
raise ValueError("All initializations are accepted either"
155155
"through initialization schemes or "
156156
"corresponding attribute but not both")
157157
else:
158-
self.initialization_schemes[initialization_to_role[key]] = kwargs[key]
158+
self.initialization_schemes[initialization_to_role[
159+
key]] = kwargs[key]
159160
kwargs.pop(key)
160161

161162
for key in self.initialization_schemes:
162163
if key not in self.initializable_roles:
163-
raise ValueError("{} is not member of ".format(str(key)) +
164-
"initializable_roles")
164+
raise ValueError("{} is not member of ".format(key) +
165+
"initializable_roles")
165166

166167
super(Initializable, self).__init__(**kwargs)
167168

168-
169169
def _validate_roles_schmes(self):
170170
for role in self.parameter_roles:
171171
if role not in self.initialization_schemes.keys():
172172
found = False
173173
for init_role in list(self.initialization_schemes.keys()):
174-
if isinstance(eval(role), type(eval(init_role))):
175-
self.initialization_schemes[role] = self.initialization_schemes[init_role]
174+
if isinstance(role, type(init_role)):
175+
self.initialization_schemes[role] = \
176+
self.initialization_schemes[init_role]
176177
found = True
177178
if not found:
178179
raise ValueError("There is no initialization_schemes"
179-
" defined for {}".format(role))
180-
180+
" defined for {}".format(role))
181181

182182
def _push_initialization_config(self):
183183
self._collect_roles()
@@ -189,19 +189,37 @@ def _push_initialization_config(self):
189189
for role, scheme in self.initialization_schemes.items():
190190
child.initialization_schemes[role] = scheme
191191

192-
193192
def _collect_roles(self):
194193
if hasattr(self, 'parameters'):
195194
for param in self.parameters:
196195
for role in param.tag.roles:
197-
if str(role) in self.initializable_roles:
198-
self.parameter_roles.update(set([str(role)]))
196+
if role in self.initializable_roles:
197+
self.parameter_roles.update(set([role]))
199198

200199
def _initialize(self):
201200
for param in self.parameters:
202201
for role in param.tag.roles:
203-
if str(role) in self.initializable_roles:
204-
self.initialization_schemes[str(role)].initialize(param, self.rng)
202+
if role in self.initializable_roles:
203+
self.initialization_schemes[role].initialize(param,
204+
self.rng)
205+
206+
def __getattr__(self, name):
207+
if name == "weights_init":
208+
if WEIGHT in self.initialization_schemes:
209+
return self.initialization_schemes[WEIGHT]
210+
elif name == "biases_init":
211+
if BIAS in self.initialization_schemes:
212+
return self.initialization_schemes[BIAS]
213+
raise AttributeError("Attribute {} not found".format(name))
214+
215+
def __setattr__(self, name, value):
216+
if name == 'weights_init':
217+
self.initialization_schemes[WEIGHT] = value
218+
elif name == 'biases_init':
219+
self.initialization_schemes[BIAS] = value
220+
else:
221+
super(Initializable, self).__setattr__(name, value)
222+
205223

206224
class LinearLike(Initializable):
207225
"""Initializable subclass with logic for :class:`Linear`-like classes.
@@ -229,7 +247,6 @@ def b(self):
229247
raise AttributeError('use_bias is False')
230248

231249

232-
233250
class Random(Brick):
234251
"""A mixin class for Bricks which need Theano RNGs.
235252

blocks/roles.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def __repr__(self):
7171
return re.sub(r'(?!^)([A-Z]+)', r'_\1',
7272
self.__class__.__name__[:-4]).upper()
7373

74+
def __hash__(self):
75+
return hash(str(self))
76+
7477

7578
class InputRole(VariableRole):
7679
pass

tests/bricks/test_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def test_attention_recurrent():
7373
state_names=wrapped.apply.states,
7474
attended_dim=attended_dim, match_dim=attended_dim)
7575
recurrent = AttentionRecurrent(wrapped, attention, seed=1234)
76-
recurrent.initialization_schemes['WEIGHT'] = IsotropicGaussian(0.5)
77-
recurrent.initialization_schemes['BIAS'] = Constant(0)
76+
recurrent.weights_init = IsotropicGaussian(0.5)
77+
recurrent.biases_init = Constant(0)
7878
recurrent.initialize()
7979

8080
attended = tensor.tensor3("attended")

tests/bricks/test_recurrent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,8 @@ def setUp(self):
542542
for _ in range(3)]
543543
self.stack = RecurrentStack(self.layers)
544544
for fork in self.stack.forks:
545-
fork.initialization_schemes['WEIGHT'] = Identity(1)
546-
fork.initialization_schemes['BIAS'] = Constant(0)
545+
fork.weights_init = Identity(1)
546+
fork.biases_init = Constant(0)
547547
self.stack.initialize()
548548

549549
self.x_val = 0.1 * numpy.asarray(

0 commit comments

Comments
 (0)