1111
1212from blocks .bricks import Initializable , Logistic , Tanh , Linear
1313from blocks .bricks .base import Application , application , Brick , lazy
14- from blocks .initialization import NdarrayInitialization
14+ from blocks .initialization import NdarrayInitialization , Constant
1515from blocks .roles import add_role , WEIGHT , INITIAL_STATE
1616from blocks .utils import (pack , shared_floatx_nans , shared_floatx_zeros ,
1717 dict_union , dict_subset , is_shared_variable )
@@ -279,6 +279,8 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
279279 def __init__ (self , dim , activation , ** kwargs ):
280280 self .dim = dim
281281 children = [activation ] + kwargs .get ('children' , [])
282+ if not 'initial_state_init' in kwargs :
283+ kwargs ['initial_state_init' ] = Constant (0. )
282284 super (SimpleRecurrent , self ).__init__ (children = children , ** kwargs )
283285
284286 @property
@@ -297,13 +299,10 @@ def _allocate(self):
297299 self .parameters .append (shared_floatx_nans ((self .dim , self .dim ),
298300 name = "W" ))
299301 add_role (self .parameters [0 ], WEIGHT )
300- self .parameters .append (shared_floatx_zeros ((self .dim ,),
302+ self .parameters .append (shared_floatx_nans ((self .dim ,),
301303 name = "initial_state" ))
302304 add_role (self .parameters [1 ], INITIAL_STATE )
303305
304- def _initialize (self ):
305- self .weights_init .initialize (self .W , self .rng )
306-
307306 @recurrent (sequences = ['inputs' , 'mask' ], states = ['states' ],
308307 outputs = ['states' ], contexts = [])
309308 def apply (self , inputs , states , mask = None ):
@@ -386,6 +385,9 @@ def __init__(self, dim, activation=None, gate_activation=None, **kwargs):
386385
387386 children = ([self .activation , self .gate_activation ] +
388387 kwargs .get ('children' , []))
388+
389+ if not 'initial_state_init' in kwargs :
390+ kwargs ['initial_state_init' ] = Constant (0. )
389391 super (LSTM , self ).__init__ (children = children , ** kwargs )
390392
391393 def get_dim (self , name ):
@@ -408,9 +410,9 @@ def _allocate(self):
408410 name = 'W_cell_to_out' )
409411 # The underscore is required to prevent collision with
410412 # the `initial_state` application method
411- self .initial_state_ = shared_floatx_zeros ((self .dim ,),
413+ self .initial_state_ = shared_floatx_nans ((self .dim ,),
412414 name = "initial_state" )
413- self .initial_cells = shared_floatx_zeros ((self .dim ,),
415+ self .initial_cells = shared_floatx_nans ((self .dim ,),
414416 name = "initial_cells" )
415417 add_role (self .W_state , WEIGHT )
416418 add_role (self .W_cell_to_in , WEIGHT )
@@ -423,10 +425,6 @@ def _allocate(self):
423425 self .W_state , self .W_cell_to_in , self .W_cell_to_forget ,
424426 self .W_cell_to_out , self .initial_state_ , self .initial_cells ]
425427
426- def _initialize (self ):
427- for weights in self .parameters [:4 ]:
428- self .weights_init .initialize (weights , self .rng )
429-
430428 @recurrent (sequences = ['inputs' , 'mask' ], states = ['states' , 'cells' ],
431429 contexts = [], outputs = ['states' , 'cells' ])
432430 def apply (self , inputs , states , cells , mask = None ):
@@ -533,6 +531,9 @@ def __init__(self, dim, activation=None, gate_activation=None,
533531 self .gate_activation = gate_activation
534532
535533 children = [activation , gate_activation ] + kwargs .get ('children' , [])
534+
535+ if not 'initial_state_init' in kwargs :
536+ kwargs ['initial_state_init' ] = Constant (0. )
536537 super (GatedRecurrent , self ).__init__ (children = children , ** kwargs )
537538
538539 @property
@@ -557,22 +558,13 @@ def _allocate(self):
557558 name = 'state_to_state' ))
558559 self .parameters .append (shared_floatx_nans ((self .dim , 2 * self .dim ),
559560 name = 'state_to_gates' ))
560- self .parameters .append (shared_floatx_zeros ((self .dim ,),
561+ self .parameters .append (shared_floatx_nans ((self .dim ,),
561562 name = "initial_state" ))
562563 for i in range (2 ):
563564 if self .parameters [i ]:
564565 add_role (self .parameters [i ], WEIGHT )
565566 add_role (self .parameters [2 ], INITIAL_STATE )
566567
567- def _initialize (self ):
568- self .weights_init .initialize (self .state_to_state , self .rng )
569- state_to_update = self .weights_init .generate (
570- self .rng , (self .dim , self .dim ))
571- state_to_reset = self .weights_init .generate (
572- self .rng , (self .dim , self .dim ))
573- self .state_to_gates .set_value (
574- numpy .hstack ([state_to_update , state_to_reset ]))
575-
576568 @recurrent (sequences = ['mask' , 'inputs' , 'gate_inputs' ],
577569 states = ['states' ], outputs = ['states' ], contexts = [])
578570 def apply (self , inputs , gate_inputs , states , mask = None ):
0 commit comments