Skip to content

Commit 22b5a09

Browse files
committed
reorganize initialization
1 parent 21ce4cf commit 22b5a09

File tree

7 files changed

+65
-65
lines changed

7 files changed

+65
-65
lines changed

blocks/bricks/interfaces.py

Lines changed: 46 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
from ..config import config
77
from .base import _Brick, Brick, lazy
8+
from blocks.roles import WEIGHT, BIAS, FILTER, INITIAL_STATE
89

910

1011
class ActivationDocumentation(_Brick):
@@ -132,6 +133,8 @@ class Initializable(RNGMixin, Brick):
132133
133134
"""
134135

136+
initializable_roles = ['WEIGHT', 'BIAS', 'FILTER', 'INITIAL_STATE']
137+
135138
@lazy()
136139
def __init__(self, initialization_schemes=None, use_bias=True,
137140
seed=None, **kwargs):
@@ -142,41 +145,62 @@ def __init__(self, initialization_schemes=None, use_bias=True,
142145
if self.initialization_schemes is None:
143146
self.initialization_schemes = {}
144147

145-
kwargs_ = {}
146-
for key in kwargs:
148+
149+
initialization_to_role = {"weights_init": 'WEIGHT', 'biases_init': 'BIAS',
150+
'initial_state_init': 'INITIAL_STATE'}
151+
for key in list(kwargs.keys()):
147152
if key[-5:] == "_init":
148-
if key in self.initialization_schemes:
153+
if initialization_to_role[key] in self.initialization_schemes.keys():
149154
raise ValueError("All initializations are accepted either"
150-
"through initialization_schemes or "
151-
"correspodong attribute but not both")
155+
"through initialization schemes or "
156+
"corresponding attribute but not both")
152157
else:
153-
self.initialization_schemes[key] = kwargs[key]
154-
else:
155-
kwargs_[key] = kwargs[key]
158+
self.initialization_schemes[initialization_to_role[key]] = kwargs[key]
159+
kwargs.pop(key)
160+
161+
for key in self.initialization_schemes:
162+
if key not in self.initializable_roles:
163+
raise ValueError("{} is not member of ".format(str(key)) +
164+
"initializable_roles")
165+
166+
super(Initializable, self).__init__(**kwargs)
167+
168+
169+
def _validate_roles_schmes(self):
170+
for role in self.parameter_roles:
171+
if role not in self.initialization_schemes.keys():
172+
found = False
173+
for init_role in list(self.initialization_schemes.keys()):
174+
if isinstance(eval(role), type(eval(init_role))):
175+
self.initialization_schemes[role] = self.initialization_schemes[init_role]
176+
found = True
177+
if not found:
178+
raise ValueError("There is no initialization_schemes"
179+
" defined for {}".format(role))
156180

157-
super(Initializable, self).__init__(**kwargs_)
158-
self._collect_roles()
159181

160182
def _push_initialization_config(self):
183+
self._collect_roles()
184+
self._validate_roles_schmes()
161185
for child in self.children:
162186
if (isinstance(child, Initializable) and
163187
hasattr(child, 'initialization_schemes')):
164-
for role in child.initialization_schemes:
165-
if role not in self.parameter_roles:
166-
raise ValueError("The parameter role: " +
167-
"{} is not defined in".format(role) +
168-
"in the class parameter_roles")
169-
170-
for child in self.children:
171-
if isinstance(child, Initializable):
172188
child.rng = self.rng
173-
child.initialization_schemes = self.initialization_schemes
189+
for role, scheme in self.initialization_schemes.items():
190+
child.initialization_schemes[role] = scheme
191+
174192

175193
def _collect_roles(self):
176-
for child in self.children:
177-
if isinstance(child, Initializable):
178-
self.parameter_roles.update(child.parameter_roles)
194+
for param in self.parameters:
195+
for role in param.tag.roles:
196+
if str(role) in self.initializable_roles:
197+
self.parameter_roles.update(set([str(role)]))
179198

199+
def _initialize(self):
200+
for param in self.parameters:
201+
for role in param.tag.roles:
202+
if str(role) in self.initializable_roles:
203+
self.initialization_schemes[str(role)].initialize(param, self.rng)
180204

181205
class LinearLike(Initializable):
182206
"""Initializable subclass with logic for :class:`Linear`-like classes.
@@ -203,14 +227,6 @@ def b(self):
203227
else:
204228
raise AttributeError('use_bias is False')
205229

206-
def _initialize(self):
207-
# Use self.parameters[] references in case W and b are overridden
208-
# to return non-shared-variables.
209-
if self.use_bias:
210-
self.initialization_schemes['biases_init'].initialize(
211-
self.parameters[1], self.rng)
212-
self.initialization_schemes['weights_init'].initialize(
213-
self.parameters[0], self.rng)
214230

215231

216232
class Random(Brick):

blocks/bricks/lookup.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,9 +41,6 @@ def _allocate(self):
4141
name='W'))
4242
add_role(self.parameters[-1], WEIGHT)
4343

44-
def _initialize(self):
45-
self.weights_init.initialize(self.W, self.rng)
46-
4744
@application(inputs=['indices'], outputs=['output'])
4845
def apply(self, indices):
4946
"""Perform lookup.

blocks/bricks/recurrent.py

Lines changed: 13 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
from blocks.bricks import Initializable, Logistic, Tanh, Linear
1313
from blocks.bricks.base import Application, application, Brick, lazy
14-
from blocks.initialization import NdarrayInitialization
14+
from blocks.initialization import NdarrayInitialization, Constant
1515
from blocks.roles import add_role, WEIGHT, INITIAL_STATE
1616
from blocks.utils import (pack, shared_floatx_nans, shared_floatx_zeros,
1717
dict_union, dict_subset, is_shared_variable)
@@ -279,6 +279,8 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
279279
def __init__(self, dim, activation, **kwargs):
280280
self.dim = dim
281281
children = [activation] + kwargs.get('children', [])
282+
if not 'initial_state_init' in kwargs:
283+
kwargs['initial_state_init'] = Constant(0.)
282284
super(SimpleRecurrent, self).__init__(children=children, **kwargs)
283285

284286
@property
@@ -297,13 +299,10 @@ def _allocate(self):
297299
self.parameters.append(shared_floatx_nans((self.dim, self.dim),
298300
name="W"))
299301
add_role(self.parameters[0], WEIGHT)
300-
self.parameters.append(shared_floatx_zeros((self.dim,),
302+
self.parameters.append(shared_floatx_nans((self.dim,),
301303
name="initial_state"))
302304
add_role(self.parameters[1], INITIAL_STATE)
303305

304-
def _initialize(self):
305-
self.weights_init.initialize(self.W, self.rng)
306-
307306
@recurrent(sequences=['inputs', 'mask'], states=['states'],
308307
outputs=['states'], contexts=[])
309308
def apply(self, inputs, states, mask=None):
@@ -386,6 +385,9 @@ def __init__(self, dim, activation=None, gate_activation=None, **kwargs):
386385

387386
children = ([self.activation, self.gate_activation] +
388387
kwargs.get('children', []))
388+
389+
if not 'initial_state_init' in kwargs:
390+
kwargs['initial_state_init'] = Constant(0.)
389391
super(LSTM, self).__init__(children=children, **kwargs)
390392

391393
def get_dim(self, name):
@@ -408,9 +410,9 @@ def _allocate(self):
408410
name='W_cell_to_out')
409411
# The underscore is required to prevent collision with
410412
# the `initial_state` application method
411-
self.initial_state_ = shared_floatx_zeros((self.dim,),
413+
self.initial_state_ = shared_floatx_nans((self.dim,),
412414
name="initial_state")
413-
self.initial_cells = shared_floatx_zeros((self.dim,),
415+
self.initial_cells = shared_floatx_nans((self.dim,),
414416
name="initial_cells")
415417
add_role(self.W_state, WEIGHT)
416418
add_role(self.W_cell_to_in, WEIGHT)
@@ -423,10 +425,6 @@ def _allocate(self):
423425
self.W_state, self.W_cell_to_in, self.W_cell_to_forget,
424426
self.W_cell_to_out, self.initial_state_, self.initial_cells]
425427

426-
def _initialize(self):
427-
for weights in self.parameters[:4]:
428-
self.weights_init.initialize(weights, self.rng)
429-
430428
@recurrent(sequences=['inputs', 'mask'], states=['states', 'cells'],
431429
contexts=[], outputs=['states', 'cells'])
432430
def apply(self, inputs, states, cells, mask=None):
@@ -533,6 +531,9 @@ def __init__(self, dim, activation=None, gate_activation=None,
533531
self.gate_activation = gate_activation
534532

535533
children = [activation, gate_activation] + kwargs.get('children', [])
534+
535+
if not 'initial_state_init' in kwargs:
536+
kwargs['initial_state_init'] = Constant(0.)
536537
super(GatedRecurrent, self).__init__(children=children, **kwargs)
537538

538539
@property
@@ -557,22 +558,13 @@ def _allocate(self):
557558
name='state_to_state'))
558559
self.parameters.append(shared_floatx_nans((self.dim, 2 * self.dim),
559560
name='state_to_gates'))
560-
self.parameters.append(shared_floatx_zeros((self.dim,),
561+
self.parameters.append(shared_floatx_nans((self.dim,),
561562
name="initial_state"))
562563
for i in range(2):
563564
if self.parameters[i]:
564565
add_role(self.parameters[i], WEIGHT)
565566
add_role(self.parameters[2], INITIAL_STATE)
566567

567-
def _initialize(self):
568-
self.weights_init.initialize(self.state_to_state, self.rng)
569-
state_to_update = self.weights_init.generate(
570-
self.rng, (self.dim, self.dim))
571-
state_to_reset = self.weights_init.generate(
572-
self.rng, (self.dim, self.dim))
573-
self.state_to_gates.set_value(
574-
numpy.hstack([state_to_update, state_to_reset]))
575-
576568
@recurrent(sequences=['mask', 'inputs', 'gate_inputs'],
577569
states=['states'], outputs=['states'], contexts=[])
578570
def apply(self, inputs, gate_inputs, states, mask=None):

blocks/bricks/simple.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@ def __init__(self, input_dim, output_dim, **kwargs):
4343
super(Linear, self).__init__(**kwargs)
4444
self.input_dim = input_dim
4545
self.output_dim = output_dim
46-
self.parameter_roles = set(['weights_init', 'biases_init'])
4746

4847
def _allocate(self):
4948
W = shared_floatx_nans((self.input_dim, self.output_dim), name='W')
@@ -96,10 +95,6 @@ def _allocate(self):
9695
add_role(b, BIAS)
9796
self.parameters.append(b)
9897

99-
def _initialize(self):
100-
b, = self.parameters
101-
self.biases_init.initialize(b, self.rng)
102-
10398
@application(inputs=['input_'], outputs=['output'])
10499
def apply(self, input_):
105100
"""Apply the linear transformation.

tests/bricks/test_attention.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,8 +73,8 @@ def test_attention_recurrent():
7373
state_names=wrapped.apply.states,
7474
attended_dim=attended_dim, match_dim=attended_dim)
7575
recurrent = AttentionRecurrent(wrapped, attention, seed=1234)
76-
recurrent.weights_init = IsotropicGaussian(0.5)
77-
recurrent.biases_init = Constant(0)
76+
recurrent.initialization_schemes['WEIGHT'] = IsotropicGaussian(0.5)
77+
recurrent.initialization_schemes['BIAS'] = Constant(0)
7878
recurrent.initialize()
7979

8080
attended = tensor.tensor3("attended")

tests/bricks/test_recurrent.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -502,7 +502,7 @@ def setUp(self):
502502
dim=3, activation=Tanh()))
503503
self.simple = SimpleRecurrent(dim=3, weights_init=Orthogonal(),
504504
activation=Tanh(), seed=1)
505-
self.bidir.allocate()
505+
self.bidir.initialize()
506506
self.simple.initialize()
507507
self.bidir.children[0].parameters[0].set_value(
508508
self.simple.parameters[0].get_value())
@@ -542,8 +542,8 @@ def setUp(self):
542542
for _ in range(3)]
543543
self.stack = RecurrentStack(self.layers)
544544
for fork in self.stack.forks:
545-
fork.weights_init = Identity(1)
546-
fork.biases_init = Constant(0)
545+
fork.initialization_schemes['WEIGHT'] = Identity(1)
546+
fork.initialization_schemes['BIAS'] = Constant(0)
547547
self.stack.initialize()
548548

549549
self.x_val = 0.1 * numpy.asarray(

tests/bricks/test_sequence_generators.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ def test_integer_sequence_generator():
160160
assert outputs_val.shape == (n_steps, batch_size)
161161
assert outputs_val.dtype == 'int64'
162162
assert costs_val.shape == (n_steps, batch_size)
163-
assert_allclose(states_val.sum(), -17.854, rtol=1e-5)
163+
assert_allclose(states_val.sum(), -17.889, rtol=1e-5)
164164
assert_allclose(costs_val.sum(), 482.868, rtol=1e-5)
165165
assert outputs_val.sum() == 629
166166

0 commit comments

Comments
 (0)