11"""Bricks that are interfaces and/or mixins."""
22import numpy
3+ import logging
34from six import add_metaclass
45from theano .sandbox .rng_mrg import MRG_RandomStreams
56
67from ..config import config
78from .base import _Brick , Brick , lazy
89from blocks .roles import WEIGHT , BIAS , FILTER , INITIAL_STATE
910
11+ logger = logging .getLogger (__name__ )
12+
1013
1114class ActivationDocumentation (_Brick ):
1215 """Dynamically adds documentation to activations.
@@ -133,24 +136,37 @@ class Initializable(RNGMixin, Brick):
133136
134137 """
135138
136- initializable_roles = [WEIGHT , BIAS , FILTER , INITIAL_STATE ]
137-
138139 @lazy ()
139- def __init__ (self , initialization_schemes = None , use_bias = True ,
140- seed = None , ** kwargs ):
140+ def __init__ (self , initialization_schemes = None , parameter_roles = None ,
141+ use_bias = True , seed = None , ** kwargs ):
141142 self .use_bias = use_bias
142143 self .seed = seed
143144 self .initialization_schemes = initialization_schemes
144- self .parameter_roles = set ([])
145145 if self .initialization_schemes is None :
146146 self .initialization_schemes = {}
147147
148+ if parameter_roles :
149+ self .parameter_roles = parameter_roles
150+ else :
151+ # logger.warning("The block has not received parameter_roles, "
152+ # "hence only the default WEIGHT and BIAS are set."
153+ # "It's a good idea to manually set the roles "
154+ # "of all the initlizable parameters inside "
155+ # "parameter_roles")
156+ self .parameter_roles = set ([WEIGHT ])
157+ if use_bias :
158+ self .parameter_roles .update (set ([BIAS ]))
159+
148160 initialization_to_role = {"weights_init" : WEIGHT , 'biases_init' : BIAS ,
149161 'initial_state_init' : INITIAL_STATE }
150162 for key in list (kwargs .keys ()):
151163 if key [- 5 :] == "_init" :
164+ if key not in initialization_to_role :
165+ raise ValueError ("The initlization scheme: {}" .format (key ),
166+ "is not defined by default, pass it"
167+ "via initialization_schemes" )
152168 if initialization_to_role [key ] in \
153- self .initialization_schemes .keys ():
169+ self .initialization_schemes .keys ():
154170 raise ValueError ("All initializations are accepted either"
155171 "through initialization schemes or "
156172 "corresponding attribute but not both" )
@@ -159,47 +175,47 @@ def __init__(self, initialization_schemes=None, use_bias=True,
159175 key ]] = kwargs [key ]
160176 kwargs .pop (key )
161177
162- for key in self .initialization_schemes :
163- if key not in self .initializable_roles :
164- raise ValueError ("{} is not member of " .format (key ) +
165- "initializable_roles" )
166-
167178 super (Initializable , self ).__init__ (** kwargs )
179+ self ._collect_roles ()
168180
169- def _validate_roles_schmes (self ):
181+ def _validate_roles (self ):
182+ high_level_roles = []
170183 for role in self .parameter_roles :
171184 if role not in self .initialization_schemes .keys ():
172- found = False
173- for init_role in list (self .initialization_schemes .keys ()):
174- if isinstance (role , type (init_role )):
185+ for key in self .initialization_schemes .keys ():
186+ if isinstance (role , type (key )):
175187 self .initialization_schemes [role ] = \
176- self .initialization_schemes [init_role ]
177- found = True
178- if not found :
179- raise ValueError ("There is no initialization_schemes"
180- " defined for {}" .format (role ))
188+ self .initialization_schemes [key ]
189+ high_level_roles .append (key )
190+
191+ for key in high_level_roles :
192+ if key not in self .parameter_roles :
193+ self .initialization_schemes .pop (key )
194+
195+ for key in self .initialization_schemes :
196+ if key not in self .parameter_roles :
197+ raise ValueError ("{} is not member of " .format (key ) +
198+ "parameter_roles" )
181199
182200 def _push_initialization_config (self ):
183- self ._collect_roles ()
184- self ._validate_roles_schmes ()
201+ self ._validate_roles ()
185202 for child in self .children :
186203 if (isinstance (child , Initializable ) and
187204 hasattr (child , 'initialization_schemes' )):
188205 child .rng = self .rng
189206 for role , scheme in self .initialization_schemes .items ():
190- child .initialization_schemes [role ] = scheme
207+ if role in child .parameter_roles :
208+ child .initialization_schemes [role ] = scheme
191209
192210 def _collect_roles (self ):
193- if hasattr (self , 'parameters' ):
194- for param in self .parameters :
195- for role in param .tag .roles :
196- if role in self .initializable_roles :
197- self .parameter_roles .update (set ([role ]))
211+ for child in self .children :
212+ if isinstance (child , Initializable ):
213+ self .parameter_roles .update (child .parameter_roles )
198214
199215 def _initialize (self ):
200216 for param in self .parameters :
201217 for role in param .tag .roles :
202- if role in self .initializable_roles :
218+ if role in self .parameter_roles :
203219 self .initialization_schemes [role ].initialize (param ,
204220 self .rng )
205221
@@ -210,7 +226,7 @@ def __getattr__(self, name):
210226 elif name == "biases_init" :
211227 if BIAS in self .initialization_schemes :
212228 return self .initialization_schemes [BIAS ]
213- raise AttributeError ( "Attribute {} not found" . format (name ) )
229+ super ( Initializable , self ). __getattr__ (name )
214230
215231 def __setattr__ (self , name , value ):
216232 if name == 'weights_init' :
@@ -235,6 +251,14 @@ class LinearLike(Initializable):
235251 first and biases (if ``use_bias`` is True) coming second.
236252
237253 """
254+
255+ def __init__ (self , ** kwargs ):
256+ if 'parameter_roles' in kwargs :
257+ kwargs ['parameter_roles' ].update (set ([WEIGHT , BIAS ]))
258+ else :
259+ kwargs ['parameter_roles' ] = set ([WEIGHT , BIAS ])
260+ super (LinearLike , self ).__init__ (** kwargs )
261+
238262 @property
239263 def W (self ):
240264 return self .parameters [0 ]
0 commit comments