@@ -150,12 +150,14 @@ class BaseSequenceGenerator(Initializable):
150150 """
@lazy()
def __init__(self, readout, transition, fork, **kwargs):
    """Store the component bricks and register them as children.

    Parameters
    ----------
    readout : object
        Stored as ``self.readout`` and registered as a child.
    transition : object
        Stored as ``self.transition`` and registered as a child.
    fork : object
        Stored as ``self.fork`` and registered as a child.
    \*\*kwargs
        Forwarded to the parent constructor. An optional ``children``
        entry is appended after this brick's own children.

    """
    self.readout = readout
    self.transition = transition
    self.fork = fork

    children = [self.readout, self.fork, self.transition]
    # Remove `children` from kwargs before forwarding: leaving it in
    # place would pass the keyword twice and raise TypeError.
    children.extend(kwargs.pop('children', None) or [])
    super(BaseSequenceGenerator, self).__init__(children=children,
                                                **kwargs)
159161
160162 @property
161163 def _state_names (self ):
@@ -508,27 +510,27 @@ class Readout(AbstractReadout):
def __init__(self, emitter=None, feedback_brick=None,
             merge=None, merge_prototype=None,
             post_merge=None, merged_dim=None, **kwargs):
    """Fill in default components and register them as children.

    Parameters
    ----------
    emitter : object, optional
        Defaults to ``TrivialEmitter(kwargs['readout_dim'])``.
    feedback_brick : object, optional
        Defaults to ``TrivialFeedback(kwargs['readout_dim'])``.
    merge : object, optional
        Defaults to a ``Merge`` built from ``kwargs['source_names']``
        and `merge_prototype`.
    merge_prototype : object, optional
        Only used when the default ``Merge`` is constructed.
    post_merge : object, optional
        Defaults to ``Bias(dim=kwargs['readout_dim'])``.
    merged_dim : object, optional
        Defaults to ``kwargs['readout_dim']``.
    \*\*kwargs
        Forwarded to the parent constructor. An optional ``children``
        entry is appended after this brick's own children.

    Notes
    -----
    The defaults are read from `kwargs` because the parent constructor
    only runs at the end. ``readout_dim`` / ``source_names`` must
    therefore be passed as keywords whenever a default is needed
    (NOTE(review): presumably a lazy-allocation decorator supplies
    them when omitted -- confirm).

    """
    # Falsy arguments (not only None) trigger the defaults below.
    if not emitter:
        emitter = TrivialEmitter(kwargs['readout_dim'])
    if not feedback_brick:
        feedback_brick = TrivialFeedback(kwargs['readout_dim'])
    if not merge:
        merge = Merge(input_names=kwargs['source_names'],
                      prototype=merge_prototype)
    if not post_merge:
        post_merge = Bias(dim=kwargs['readout_dim'])
    if not merged_dim:
        merged_dim = kwargs['readout_dim']
    self.emitter = emitter
    self.feedback_brick = feedback_brick
    self.merge = merge
    self.post_merge = post_merge
    self.merged_dim = merged_dim

    children = [self.emitter, self.feedback_brick, self.merge,
                self.post_merge]
    # Remove `children` from kwargs before forwarding: leaving it in
    # place would pass the keyword twice and raise TypeError.
    children.extend(kwargs.pop('children', None) or [])
    super(Readout, self).__init__(children=children, **kwargs)
532534
533535 def _push_allocation_config (self ):
534536 self .emitter .readout_dim = self .get_dim ('readouts' )
@@ -684,10 +686,10 @@ class SoftmaxEmitter(AbstractEmitter, Initializable, Random):
684686
685687 """
def __init__(self, initial_output=0, **kwargs):
    """Create the softmax child brick and initialize the parents.

    Parameters
    ----------
    initial_output : object, optional
        Stored as ``self.initial_output``. Defaults to 0.
    \*\*kwargs
        Forwarded to the parent constructor. An optional ``children``
        entry is appended after this brick's own child.

    """
    self.initial_output = initial_output
    self.softmax = NDimensionalSoftmax()

    children = [self.softmax]
    # Remove `children` from kwargs before forwarding: leaving it in
    # place would pass the keyword twice and raise TypeError.
    children.extend(kwargs.pop('children', None) or [])
    super(SoftmaxEmitter, self).__init__(children=children, **kwargs)
691693
692694 @application
693695 def probs (self , readouts ):
@@ -743,13 +745,12 @@ class LookupFeedback(AbstractFeedback, Initializable):
743745
744746 """
def __init__(self, num_outputs=None, feedback_dim=None, **kwargs):
    """Create the lookup table child and initialize the parents.

    Parameters
    ----------
    num_outputs : object, optional
        Stored as ``self.num_outputs`` and passed as the first
        argument of ``LookupTable``.
    feedback_dim : object, optional
        Stored as ``self.feedback_dim`` and passed as the second
        argument of ``LookupTable``.
    \*\*kwargs
        Forwarded to the parent constructor. An optional ``children``
        entry is appended after this brick's own child.

    """
    self.num_outputs = num_outputs
    self.feedback_dim = feedback_dim

    # No `weights_init` is passed here: the parent constructor has not
    # run yet, so the attribute is not set on `self` at this point.
    # NOTE(review): presumably the initialization scheme reaches the
    # lookup table through the parent's push mechanism -- confirm.
    self.lookup = LookupTable(num_outputs, feedback_dim)

    children = [self.lookup]
    # Remove `children` from kwargs before forwarding: leaving it in
    # place would pass the keyword twice and raise TypeError.
    children.extend(kwargs.pop('children', None) or [])
    super(LookupFeedback, self).__init__(children=children, **kwargs)
753754
754755 def _push_allocation_config (self ):
755756 self .lookup .length = self .num_outputs
@@ -784,14 +785,15 @@ class FakeAttentionRecurrent(AbstractAttentionRecurrent, Initializable):
784785
785786 """
def __init__(self, transition, **kwargs):
    """Wrap `transition` and expose its state and context names.

    Parameters
    ----------
    transition : object
        Stored as ``self.transition`` and registered as a child. Its
        ``apply.states`` and ``apply.contexts`` are copied to
        ``self.state_names`` and ``self.context_names``.
    \*\*kwargs
        Forwarded to the parent constructor. An optional ``children``
        entry is appended after this brick's own child.

    """
    self.transition = transition

    self.state_names = transition.apply.states
    self.context_names = transition.apply.contexts
    # Always initialized to an empty list.
    self.glimpse_names = []

    children = [self.transition]
    # Remove `children` from kwargs before forwarding: leaving it in
    # place would pass the keyword twice and raise TypeError.
    children.extend(kwargs.pop('children', None) or [])
    super(FakeAttentionRecurrent, self).__init__(children=children,
                                                 **kwargs)
795797
796798 @application
797799 def apply (self , * args , ** kwargs ):
0 commit comments