Skip to content

Commit 45d4456

Browse files
committed
get rid of initializable_roles
1 parent 5087381 commit 45d4456

File tree

3 files changed

+75
-41
lines changed

3 files changed

+75
-41
lines changed

blocks/bricks/conv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ class Convolutional(LinearLike):
7272
def __init__(self, filter_size, num_filters, num_channels, batch_size=None,
7373
image_size=(None, None), step=(1, 1), border_mode='valid',
7474
tied_biases=False, **kwargs):
75-
super(Convolutional, self).__init__(**kwargs)
75+
parameter_roles = set([FILTER, BIAS])
76+
super(Convolutional, self).__init__(parameter_roles=parameter_roles,
77+
**kwargs)
7678

7779
self.filter_size = filter_size
7880
self.num_filters = num_filters

blocks/bricks/interfaces.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""Bricks that are interfaces and/or mixins."""
22
import numpy
3+
import logging
34
from six import add_metaclass
45
from theano.sandbox.rng_mrg import MRG_RandomStreams
56

67
from ..config import config
78
from .base import _Brick, Brick, lazy
89
from blocks.roles import WEIGHT, BIAS, FILTER, INITIAL_STATE
910

11+
logger = logging.getLogger(__name__)
12+
1013

1114
class ActivationDocumentation(_Brick):
1215
"""Dynamically adds documentation to activations.
@@ -133,24 +136,37 @@ class Initializable(RNGMixin, Brick):
133136
134137
"""
135138

136-
initializable_roles = [WEIGHT, BIAS, FILTER, INITIAL_STATE]
137-
138139
@lazy()
139-
def __init__(self, initialization_schemes=None, use_bias=True,
140-
seed=None, **kwargs):
140+
def __init__(self, initialization_schemes=None, parameter_roles=None,
141+
use_bias=True, seed=None, **kwargs):
141142
self.use_bias = use_bias
142143
self.seed = seed
143144
self.initialization_schemes = initialization_schemes
144-
self.parameter_roles = set([])
145145
if self.initialization_schemes is None:
146146
self.initialization_schemes = {}
147147

148+
if parameter_roles:
149+
self.parameter_roles = parameter_roles
150+
else:
151+
# logger.warning("The block has not received parameter_roles, "
152+
# "hence only the default WEIGHT and BIAS are set."
153+
# "It's a good idea to manually set the roles "
154+
# "of all the initializable parameters inside "
155+
# "parameter_roles")
156+
self.parameter_roles = set([WEIGHT])
157+
if use_bias:
158+
self.parameter_roles.update(set([BIAS]))
159+
148160
initialization_to_role = {"weights_init": WEIGHT, 'biases_init': BIAS,
149161
'initial_state_init': INITIAL_STATE}
150162
for key in list(kwargs.keys()):
151163
if key[-5:] == "_init":
164+
if key not in initialization_to_role:
165+
raise ValueError("The initialization scheme: {}".format(key),
166+
"is not defined by default, pass it"
167+
"via initialization_schemes")
152168
if initialization_to_role[key] in \
153-
self.initialization_schemes.keys():
169+
self.initialization_schemes.keys():
154170
raise ValueError("All initializations are accepted either"
155171
"through initialization schemes or "
156172
"corresponding attribute but not both")
@@ -159,47 +175,47 @@ def __init__(self, initialization_schemes=None, use_bias=True,
159175
key]] = kwargs[key]
160176
kwargs.pop(key)
161177

162-
for key in self.initialization_schemes:
163-
if key not in self.initializable_roles:
164-
raise ValueError("{} is not member of ".format(key) +
165-
"initializable_roles")
166-
167178
super(Initializable, self).__init__(**kwargs)
179+
self._collect_roles()
168180

169-
def _validate_roles_schmes(self):
181+
def _validate_roles(self):
182+
high_level_roles = []
170183
for role in self.parameter_roles:
171184
if role not in self.initialization_schemes.keys():
172-
found = False
173-
for init_role in list(self.initialization_schemes.keys()):
174-
if isinstance(role, type(init_role)):
185+
for key in self.initialization_schemes.keys():
186+
if isinstance(role, type(key)):
175187
self.initialization_schemes[role] = \
176-
self.initialization_schemes[init_role]
177-
found = True
178-
if not found:
179-
raise ValueError("There is no initialization_schemes"
180-
" defined for {}".format(role))
188+
self.initialization_schemes[key]
189+
high_level_roles.append(key)
190+
191+
for key in high_level_roles:
192+
if key not in self.parameter_roles:
193+
self.initialization_schemes.pop(key)
194+
195+
for key in self.initialization_schemes:
196+
if key not in self.parameter_roles:
197+
raise ValueError("{} is not member of ".format(key) +
198+
"parameter_roles")
181199

182200
def _push_initialization_config(self):
183-
self._collect_roles()
184-
self._validate_roles_schmes()
201+
self._validate_roles()
185202
for child in self.children:
186203
if (isinstance(child, Initializable) and
187204
hasattr(child, 'initialization_schemes')):
188205
child.rng = self.rng
189206
for role, scheme in self.initialization_schemes.items():
190-
child.initialization_schemes[role] = scheme
207+
if role in child.parameter_roles:
208+
child.initialization_schemes[role] = scheme
191209

192210
def _collect_roles(self):
193-
if hasattr(self, 'parameters'):
194-
for param in self.parameters:
195-
for role in param.tag.roles:
196-
if role in self.initializable_roles:
197-
self.parameter_roles.update(set([role]))
211+
for child in self.children:
212+
if isinstance(child, Initializable):
213+
self.parameter_roles.update(child.parameter_roles)
198214

199215
def _initialize(self):
200216
for param in self.parameters:
201217
for role in param.tag.roles:
202-
if role in self.initializable_roles:
218+
if role in self.parameter_roles:
203219
self.initialization_schemes[role].initialize(param,
204220
self.rng)
205221

@@ -210,7 +226,7 @@ def __getattr__(self, name):
210226
elif name == "biases_init":
211227
if BIAS in self.initialization_schemes:
212228
return self.initialization_schemes[BIAS]
213-
raise AttributeError("Attribute {} not found".format(name))
229+
super(Initializable, self).__getattr__(name)
214230

215231
def __setattr__(self, name, value):
216232
if name == 'weights_init':
@@ -235,6 +251,14 @@ class LinearLike(Initializable):
235251
first and biases (if ``use_bias`` is True) coming second.
236252
237253
"""
254+
255+
def __init__(self, **kwargs):
256+
if 'parameter_roles' in kwargs:
257+
kwargs['parameter_roles'].update(set([WEIGHT, BIAS]))
258+
else:
259+
kwargs['parameter_roles'] = set([WEIGHT, BIAS])
260+
super(LinearLike, self).__init__(**kwargs)
261+
238262
@property
239263
def W(self):
240264
return self.parameters[0]

blocks/bricks/recurrent.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from blocks.bricks import Initializable, Logistic, Tanh, Linear
1313
from blocks.bricks.base import Application, application, Brick, lazy
1414
from blocks.initialization import NdarrayInitialization, Constant
15-
from blocks.roles import add_role, WEIGHT, INITIAL_STATE
15+
from blocks.roles import add_role, WEIGHT, BIAS, INITIAL_STATE
1616
from blocks.utils import (pack, shared_floatx_nans, shared_floatx_zeros,
1717
dict_union, dict_subset, is_shared_variable)
1818
from blocks.bricks.parallel import Fork
@@ -279,9 +279,12 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
279279
def __init__(self, dim, activation, **kwargs):
280280
self.dim = dim
281281
children = [activation] + kwargs.get('children', [])
282-
if not 'initial_state_init' in kwargs:
282+
if 'initial_state_init' not in kwargs:
283283
kwargs['initial_state_init'] = Constant(0.)
284-
super(SimpleRecurrent, self).__init__(children=children, **kwargs)
284+
parameter_roles = set([WEIGHT, BIAS, INITIAL_STATE])
285+
super(SimpleRecurrent, self).__init__(children=children,
286+
parameter_roles=parameter_roles,
287+
**kwargs)
285288

286289
@property
287290
def W(self):
@@ -300,7 +303,7 @@ def _allocate(self):
300303
name="W"))
301304
add_role(self.parameters[0], WEIGHT)
302305
self.parameters.append(shared_floatx_nans((self.dim,),
303-
name="initial_state"))
306+
name="initial_state"))
304307
add_role(self.parameters[1], INITIAL_STATE)
305308

306309
@recurrent(sequences=['inputs', 'mask'], states=['states'],
@@ -386,9 +389,11 @@ def __init__(self, dim, activation=None, gate_activation=None, **kwargs):
386389
children = ([self.activation, self.gate_activation] +
387390
kwargs.get('children', []))
388391

389-
if not 'initial_state_init' in kwargs:
392+
if 'initial_state_init' not in kwargs:
390393
kwargs['initial_state_init'] = Constant(0.)
391-
super(LSTM, self).__init__(children=children, **kwargs)
394+
parameter_roles = set([WEIGHT, BIAS, INITIAL_STATE])
395+
super(LSTM, self).__init__(children=children,
396+
parameter_roles=parameter_roles, **kwargs)
392397

393398
def get_dim(self, name):
394399
if name == 'inputs':
@@ -411,9 +416,9 @@ def _allocate(self):
411416
# The underscore is required to prevent collision with
412417
# the `initial_state` application method
413418
self.initial_state_ = shared_floatx_nans((self.dim,),
414-
name="initial_state")
419+
name="initial_state")
415420
self.initial_cells = shared_floatx_nans((self.dim,),
416-
name="initial_cells")
421+
name="initial_cells")
417422
add_role(self.W_state, WEIGHT)
418423
add_role(self.W_cell_to_in, WEIGHT)
419424
add_role(self.W_cell_to_forget, WEIGHT)
@@ -532,9 +537,12 @@ def __init__(self, dim, activation=None, gate_activation=None,
532537

533538
children = [activation, gate_activation] + kwargs.get('children', [])
534539

535-
if not 'initial_state_init' in kwargs:
540+
if 'initial_state_init' not in kwargs:
536541
kwargs['initial_state_init'] = Constant(0.)
537-
super(GatedRecurrent, self).__init__(children=children, **kwargs)
542+
parameter_roles = set([WEIGHT, BIAS, INITIAL_STATE])
543+
super(GatedRecurrent, self).__init__(children=children,
544+
parameter_roles=parameter_roles,
545+
**kwargs)
538546

539547
@property
540548
def state_to_state(self):

0 commit comments

Comments
 (0)