Skip to content

Commit 0e1696d

Browse files
committed
Merge pull request #1029 from memimo/children2
Pass children as argument to Brick
2 parents 0bd6ee8 + e7c6e84 commit 0e1696d

File tree

7 files changed, with 60 additions and 49 deletions.

blocks/bricks/attention.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,12 @@ class SequenceContentAttention(GenericSequenceAttention, Initializable):
309309
@lazy(allocation=['match_dim'])
310310
def __init__(self, match_dim, state_transformer=None,
311311
attended_transformer=None, energy_computer=None, **kwargs):
312-
super(SequenceContentAttention, self).__init__(**kwargs)
313312
if not state_transformer:
314313
state_transformer = Linear(use_bias=False)
315314
self.match_dim = match_dim
316315
self.state_transformer = state_transformer
317316

318-
self.state_transformers = Parallel(input_names=self.state_names,
317+
self.state_transformers = Parallel(input_names=kwargs['state_names'],
319318
prototype=state_transformer,
320319
name="state_trans")
321320
if not attended_transformer:
@@ -325,8 +324,10 @@ def __init__(self, match_dim, state_transformer=None,
325324
self.attended_transformer = attended_transformer
326325
self.energy_computer = energy_computer
327326

328-
self.children = [self.state_transformers, attended_transformer,
329-
energy_computer]
327+
children = [self.state_transformers, attended_transformer,
328+
energy_computer] + kwargs.get('children', [])
329+
super(SequenceContentAttention, self).__init__(children=children,
330+
**kwargs)
330331

331332
def _push_allocation_config(self):
332333
self.state_transformers.input_dims = self.state_dims
@@ -540,7 +541,6 @@ def __init__(self, transition, attention, distribute=None,
540541
add_contexts=True,
541542
attended_name=None, attended_mask_name=None,
542543
**kwargs):
543-
super(AttentionRecurrent, self).__init__(**kwargs)
544544
self._sequence_names = list(transition.apply.sequences)
545545
self._state_names = list(transition.apply.states)
546546
self._context_names = list(transition.apply.contexts)
@@ -575,7 +575,9 @@ def __init__(self, transition, attention, distribute=None,
575575
name for name in self._glimpse_names
576576
if name in self.attention.take_glimpses.inputs]
577577

578-
self.children = [self.transition, self.attention, self.distribute]
578+
children = [self.transition, self.attention, self.distribute]
579+
children += kwargs.get('children', [])
580+
super(AttentionRecurrent, self).__init__(children=children, **kwargs)
579581

580582
def _push_allocation_config(self):
581583
self.attention.state_dims = self.transition.get_dims(

blocks/bricks/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,12 +546,15 @@ class Brick(Annotation):
546546
#: See :attr:`Brick.print_shapes`
547547
print_shapes = False
548548

549-
def __init__(self, name=None):
549+
def __init__(self, name=None, children=None):
550550
if name is None:
551551
name = self.__class__.__name__.lower()
552-
self.name = name
553552

554-
self.children = []
553+
if children is None:
554+
children = []
555+
556+
self.name = name
557+
self.children = children
555558
self.parents = []
556559

557560
self.allocated = False

blocks/bricks/recurrent.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ class SimpleRecurrent(BaseRecurrent, Initializable):
275275
"""
276276
@lazy(allocation=['dim'])
277277
def __init__(self, dim, activation, **kwargs):
278-
super(SimpleRecurrent, self).__init__(**kwargs)
279278
self.dim = dim
280-
self.children = [activation]
279+
children = [activation] + kwargs.get('children', [])
280+
super(SimpleRecurrent, self).__init__(children=children, **kwargs)
281281

282282
@property
283283
def W(self):
@@ -370,12 +370,12 @@ class LSTM(BaseRecurrent, Initializable):
370370
"""
371371
@lazy(allocation=['dim'])
372372
def __init__(self, dim, activation=None, **kwargs):
373-
super(LSTM, self).__init__(**kwargs)
374373
self.dim = dim
375374

376375
if not activation:
377376
activation = Tanh()
378-
self.children = [activation]
377+
children = [activation] + kwargs.get('children', [])
378+
super(LSTM, self).__init__(children=children, **kwargs)
379379

380380
def get_dim(self, name):
381381
if name == 'inputs':
@@ -513,7 +513,6 @@ class GatedRecurrent(BaseRecurrent, Initializable):
513513
@lazy(allocation=['dim'])
514514
def __init__(self, dim, activation=None, gate_activation=None,
515515
**kwargs):
516-
super(GatedRecurrent, self).__init__(**kwargs)
517516
self.dim = dim
518517

519518
if not activation:
@@ -523,7 +522,8 @@ def __init__(self, dim, activation=None, gate_activation=None,
523522
self.activation = activation
524523
self.gate_activation = gate_activation
525524

526-
self.children = [activation, gate_activation]
525+
children = [activation, gate_activation] + kwargs.get('children', [])
526+
super(GatedRecurrent, self).__init__(children=children, **kwargs)
527527

528528
@property
529529
def state_to_state(self):
@@ -629,12 +629,13 @@ class Bidirectional(Initializable):
629629

630630
@lazy()
631631
def __init__(self, prototype, **kwargs):
632-
super(Bidirectional, self).__init__(**kwargs)
633632
self.prototype = prototype
634633

635-
self.children = [copy.deepcopy(prototype) for _ in range(2)]
636-
self.children[0].name = 'forward'
637-
self.children[1].name = 'backward'
634+
children = [copy.deepcopy(prototype) for _ in range(2)]
635+
children[0].name = 'forward'
636+
children[1].name = 'backward'
637+
children += kwargs.get('children', [])
638+
super(Bidirectional, self).__init__(children=children, **kwargs)
638639

639640
@application
640641
def apply(self, *args, **kwargs):

blocks/bricks/sequence_generators.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,14 @@ class BaseSequenceGenerator(Initializable):
150150
"""
151151
@lazy()
152152
def __init__(self, readout, transition, fork, **kwargs):
153-
super(BaseSequenceGenerator, self).__init__(**kwargs)
154153
self.readout = readout
155154
self.transition = transition
156155
self.fork = fork
157156

158-
self.children = [self.readout, self.fork, self.transition]
157+
children = [self.readout, self.fork, self.transition]
158+
children += kwargs.get('children', [])
159+
super(BaseSequenceGenerator, self).__init__(children=children,
160+
**kwargs)
159161

160162
@property
161163
def _state_names(self):
@@ -508,27 +510,27 @@ class Readout(AbstractReadout):
508510
def __init__(self, emitter=None, feedback_brick=None,
509511
merge=None, merge_prototype=None,
510512
post_merge=None, merged_dim=None, **kwargs):
511-
super(Readout, self).__init__(**kwargs)
512513

513514
if not emitter:
514-
emitter = TrivialEmitter(self.readout_dim)
515+
emitter = TrivialEmitter(kwargs['readout_dim'])
515516
if not feedback_brick:
516-
feedback_brick = TrivialFeedback(self.readout_dim)
517+
feedback_brick = TrivialFeedback(kwargs['readout_dim'])
517518
if not merge:
518-
merge = Merge(input_names=self.source_names,
519+
merge = Merge(input_names=kwargs['source_names'],
519520
prototype=merge_prototype)
520521
if not post_merge:
521-
post_merge = Bias(dim=self.readout_dim)
522+
post_merge = Bias(dim=kwargs['readout_dim'])
522523
if not merged_dim:
523-
merged_dim = self.readout_dim
524+
merged_dim = kwargs['readout_dim']
524525
self.emitter = emitter
525526
self.feedback_brick = feedback_brick
526527
self.merge = merge
527528
self.post_merge = post_merge
528529
self.merged_dim = merged_dim
529530

530-
self.children = [self.emitter, self.feedback_brick,
531-
self.merge, self.post_merge]
531+
children = [self.emitter, self.feedback_brick, self.merge,
532+
self.post_merge] + kwargs.get('children', [])
533+
super(Readout, self).__init__(children=children, **kwargs)
532534

533535
def _push_allocation_config(self):
534536
self.emitter.readout_dim = self.get_dim('readouts')
@@ -684,10 +686,10 @@ class SoftmaxEmitter(AbstractEmitter, Initializable, Random):
684686
685687
"""
686688
def __init__(self, initial_output=0, **kwargs):
687-
super(SoftmaxEmitter, self).__init__(**kwargs)
688689
self.initial_output = initial_output
689690
self.softmax = NDimensionalSoftmax()
690-
self.children = [self.softmax]
691+
children = [self.softmax] + kwargs.get('children', [])
692+
super(SoftmaxEmitter, self).__init__(children=children, **kwargs)
691693

692694
@application
693695
def probs(self, readouts):
@@ -743,13 +745,12 @@ class LookupFeedback(AbstractFeedback, Initializable):
743745
744746
"""
745747
def __init__(self, num_outputs=None, feedback_dim=None, **kwargs):
746-
super(LookupFeedback, self).__init__(**kwargs)
747748
self.num_outputs = num_outputs
748749
self.feedback_dim = feedback_dim
749750

750-
self.lookup = LookupTable(num_outputs, feedback_dim,
751-
weights_init=self.weights_init)
752-
self.children = [self.lookup]
751+
self.lookup = LookupTable(num_outputs, feedback_dim)
752+
children = [self.lookup] + kwargs.get('children', [])
753+
super(LookupFeedback, self).__init__(children=children, **kwargs)
753754

754755
def _push_allocation_config(self):
755756
self.lookup.length = self.num_outputs
@@ -784,14 +785,15 @@ class FakeAttentionRecurrent(AbstractAttentionRecurrent, Initializable):
784785
785786
"""
786787
def __init__(self, transition, **kwargs):
787-
super(FakeAttentionRecurrent, self).__init__(**kwargs)
788788
self.transition = transition
789789

790790
self.state_names = transition.apply.states
791791
self.context_names = transition.apply.contexts
792792
self.glimpse_names = []
793793

794-
self.children = [self.transition]
794+
children = [self.transition] + kwargs.get('children', [])
795+
super(FakeAttentionRecurrent, self).__init__(children=children,
796+
**kwargs)
795797

796798
@application
797799
def apply(self, *args, **kwargs):

blocks/bricks/sequences.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ class Sequence(Brick):
2222
2323
"""
2424
def __init__(self, application_methods, **kwargs):
25-
super(Sequence, self).__init__(**kwargs)
2625
self.application_methods = application_methods
2726

2827
seen = set()
29-
self.children = [app.brick for app in application_methods
30-
if not (app.brick in seen or seen.add(app.brick))]
28+
children = [app.brick for app in application_methods
29+
if not (app.brick in seen or seen.add(app.brick))]
30+
children += kwargs.get('children', [])
31+
super(Sequence, self).__init__(children=children, **kwargs)
3132

3233
@application
3334
def apply(self, *args):

blocks/bricks/simple.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,10 @@ class LinearMaxout(Initializable, Feedforward):
204204
"""
205205
@lazy(allocation=['input_dim', 'output_dim', 'num_pieces'])
206206
def __init__(self, input_dim, output_dim, num_pieces, **kwargs):
207-
super(LinearMaxout, self).__init__(**kwargs)
208207
self.linear = Linear()
209208
self.maxout = Maxout()
210-
self.children = [self.linear,
211-
self.maxout]
209+
children = [self.linear, self.maxout] + kwargs.get('children', [])
210+
super(LinearMaxout, self).__init__(children=children, **kwargs)
212211

213212
self.input_dim = input_dim
214213
self.output_dim = output_dim

docs/create_your_own_brick.rst

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ of :class:`.Brick` for a precise description of the life-cycle of a brick):
5757

5858
* :meth:`.Brick.__init__`: you should pass by argument the attributes of your
5959
brick. It is also in this method that you should create the potential
60-
"children bricks" that belong to your brick (in that case, you have to put
61-
the children bricks into ``self.children``). The initialization of the
60+
"children bricks" that belong to your brick (in that case, you have to pass
61+
the children bricks to ``super().__init__``). The initialization of the
6262
attributes can be lazy as described later in the tutorial.
6363
* :meth:`apply`: you need to implement a method that actually
6464
implements the operation of the brick, taking as arguments the inputs
@@ -210,10 +210,11 @@ specify the ``input_dim`` of ``brick2`` directly at its creation.
210210
>>> class ChainOfTwoFeedforward(Feedforward):
211211
... """Two sequential Feedforward bricks."""
212212
... def __init__(self, brick1, brick2, **kwargs):
213-
... super(Feedforward, self).__init__(**kwargs)
214213
... self.brick1 = brick1
215214
... self.brick2 = brick2
216-
... self.children = [self.brick1, self.brick2]
215+
... children = [self.brick1, self.brick2]
216+
... children += kwargs.get('children', [])
217+
... super(Feedforward, self).__init__(children=children, **kwargs)
217218
...
218219
... @property
219220
... def input_dim(self):
@@ -370,12 +371,14 @@ One can also create the brick using :class:`Linear` children bricks, which
370371
>>> class ParallelLinear2(Initializable):
371372
... def __init__(self, input_dim1, input_dim2, output_dim1, output_dim2,
372373
... **kwargs):
373-
... super(ParallelLinear2, self).__init__(**kwargs)
374374
... self.linear1 = Linear(input_dim1, output_dim1,
375375
... use_bias=False, **kwargs)
376376
... self.linear2 = Linear(input_dim2, output_dim2,
377377
... use_bias=False, **kwargs)
378-
... self.children = [self.linear1, self.linear2]
378+
... children = [self.linear1, self.linear2]
379+
... children += kwargs.get('children', [])
380+
... super(ParallelLinear2, self).__init__(children=children,
381+
... **kwargs)
379382
...
380383
... @application(inputs=['input1_', 'input2_'], outputs=['output1',
381384
... 'output2'])

0 commit comments

Comments
 (0)