@@ -133,7 +133,7 @@ class Initializable(RNGMixin, Brick):
133133
134134 """
135135
136- initializable_roles = [' WEIGHT' , ' BIAS' , ' FILTER' , ' INITIAL_STATE' ]
136+ initializable_roles = [WEIGHT , BIAS , FILTER , INITIAL_STATE ]
137137
138138 @lazy ()
139139 def __init__ (self , initialization_schemes = None , use_bias = True ,
@@ -145,39 +145,39 @@ def __init__(self, initialization_schemes=None, use_bias=True,
145145 if self .initialization_schemes is None :
146146 self .initialization_schemes = {}
147147
148-
149- initialization_to_role = {"weights_init" : 'WEIGHT' , 'biases_init' : 'BIAS' ,
150- 'initial_state_init' : 'INITIAL_STATE' }
148+ initialization_to_role = {"weights_init" : WEIGHT , 'biases_init' : BIAS ,
149+ 'initial_state_init' : INITIAL_STATE }
151150 for key in list (kwargs .keys ()):
152151 if key [- 5 :] == "_init" :
153- if initialization_to_role [key ] in self .initialization_schemes .keys ():
152+ if initialization_to_role [key ] in \
153+ self .initialization_schemes .keys ():
154154 raise ValueError ("All initializations are accepted either"
155155 "through initialization schemes or "
156156 "corresponding attribute but not both" )
157157 else :
158- self .initialization_schemes [initialization_to_role [key ]] = kwargs [key ]
158+ self .initialization_schemes [initialization_to_role [
159+ key ]] = kwargs [key ]
159160 kwargs .pop (key )
160161
161162 for key in self .initialization_schemes :
162163 if key not in self .initializable_roles :
163- raise ValueError ("{} is not member of " .format (str ( key ) ) +
164- "initializable_roles" )
164+ raise ValueError ("{} is not member of " .format (key ) +
165+ "initializable_roles" )
165166
166167 super (Initializable , self ).__init__ (** kwargs )
167168
168-
169169 def _validate_roles_schmes (self ):
170170 for role in self .parameter_roles :
171171 if role not in self .initialization_schemes .keys ():
172172 found = False
173173 for init_role in list (self .initialization_schemes .keys ()):
174- if isinstance (eval (role ), type (eval (init_role ))):
175- self .initialization_schemes [role ] = self .initialization_schemes [init_role ]
174+ if isinstance (role , type (init_role )):
175+ self .initialization_schemes [role ] = \
176+ self .initialization_schemes [init_role ]
176177 found = True
177178 if not found :
178179 raise ValueError ("There is no initialization_schemes"
179- " defined for {}" .format (role ))
180-
180+ " defined for {}" .format (role ))
181181
182182 def _push_initialization_config (self ):
183183 self ._collect_roles ()
@@ -189,19 +189,37 @@ def _push_initialization_config(self):
189189 for role , scheme in self .initialization_schemes .items ():
190190 child .initialization_schemes [role ] = scheme
191191
192-
193192 def _collect_roles (self ):
194193 if hasattr (self , 'parameters' ):
195194 for param in self .parameters :
196195 for role in param .tag .roles :
197- if str ( role ) in self .initializable_roles :
198- self .parameter_roles .update (set ([str ( role ) ]))
196+ if role in self .initializable_roles :
197+ self .parameter_roles .update (set ([role ]))
199198
200199 def _initialize (self ):
201200 for param in self .parameters :
202201 for role in param .tag .roles :
203- if str (role ) in self .initializable_roles :
204- self .initialization_schemes [str (role )].initialize (param , self .rng )
202+ if role in self .initializable_roles :
203+ self .initialization_schemes [role ].initialize (param ,
204+ self .rng )
205+
206+ def __getattr__ (self , name ):
207+ if name == "weights_init" :
208+ if WEIGHT in self .initialization_schemes :
209+ return self .initialization_schemes [WEIGHT ]
210+ elif name == "biases_init" :
211+ if BIAS in self .initialization_schemes :
212+ return self .initialization_schemes [BIAS ]
213+ raise AttributeError ("Attribute {} not found" .format (name ))
214+
215+ def __setattr__ (self , name , value ):
216+ if name == 'weights_init' :
217+ self .initialization_schemes [WEIGHT ] = value
218+ elif name == 'biases_init' :
219+ self .initialization_schemes [BIAS ] = value
220+ else :
221+ super (Initializable , self ).__setattr__ (name , value )
222+
205223
206224class LinearLike (Initializable ):
207225 """Initializable subclass with logic for :class:`Linear`-like classes.
@@ -229,7 +247,6 @@ def b(self):
229247 raise AttributeError ('use_bias is False' )
230248
231249
232-
233250class Random (Brick ):
234251 """A mixin class for Bricks which need Theano RNGs.
235252
0 commit comments