Multiple top_bricks are found while building the nested recurrent model, even though there is in fact only one, so an error occurs. The theano.function compiled from the nested recurrent's apply method works fine, however.
I think the error occurs because the outermost scan does not expose all of the variables to Blocks when the computation graph is built.
The bug can be reproduced with the following code:
```python
import numpy
from theano import tensor, function

from blocks.bricks import Initializable, Linear
from blocks.bricks.recurrent import BaseRecurrent, GatedRecurrent, recurrent
from blocks.bricks.parallel import Fork
from blocks.initialization import IsotropicGaussian, Constant
from blocks.utils import dict_union
from blocks.model import Model

class InnerRecurrent(BaseRecurrent, Initializable):
    def __init__(self, inner_input_dim, outer_input_dim, inner_dim, **kwargs):
        self.inner_gru = GatedRecurrent(dim=inner_dim, name='inner_gru')
        self.inner_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=inner_input_dim, name='inner_input_fork')
        self.outer_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=outer_input_dim, name='inner_outer_fork')
        super(InnerRecurrent, self).__init__(**kwargs)
        self.children = [
            self.inner_gru, self.inner_input_fork, self.outer_input_fork]

    def _push_allocation_config(self):
        self.inner_input_fork.output_dims = self.inner_gru.get_dims(
            self.inner_input_fork.output_names)
        self.outer_input_fork.output_dims = self.inner_gru.get_dims(
            self.outer_input_fork.output_names)

    @recurrent(sequences=['inner_inputs'], states=['states'],
               contexts=['outer_inputs'], outputs=['states'])
    def apply(self, inner_inputs, states, outer_inputs):
        forked_inputs = self.inner_input_fork.apply(inner_inputs, as_dict=True)
        forked_states = self.outer_input_fork.apply(outer_inputs, as_dict=True)
        gru_inputs = {key: forked_inputs[key] + forked_states[key]
                      for key in forked_inputs.keys()}
        new_states = self.inner_gru.apply(
            iterate=False,
            **dict_union(gru_inputs, {'states': states}))
        return new_states

    def get_dim(self, name):
        if name == 'states':
            return self.inner_gru.get_dim(name)
        else:
            raise AttributeError(name)

class OuterLinear(BaseRecurrent, Initializable):
    def __init__(self, inner_recurrent, inner_dim, **kwargs):
        self.inner_recurrent = inner_recurrent
        self.linear_map = Linear(input_dim=inner_dim, output_dim=1)
        super(OuterLinear, self).__init__(**kwargs)
        self.children = [self.inner_recurrent, self.linear_map]

    @recurrent(sequences=['outer_inputs'], states=[],
               contexts=['inner_inputs'], outputs=['weighted_averages'])
    def apply(self, outer_inputs, inner_inputs):
        inner_states = self.inner_recurrent.apply(
            inner_inputs=inner_inputs, outer_inputs=outer_inputs)
        linear_outs = self.linear_map.apply(inner_states)
        # mean over the inner time axis
        return linear_outs.mean(axis=0)

def test_nested_recurrent():
    inits = {
        'weights_init': IsotropicGaussian(0.01),
        'biases_init': Constant(0.)}
    inner_input_dim = 11
    outer_input_dim = 17
    inner_dim = 5
    batch_size = 3
    inner_steps = 19
    outer_steps = 7

    inner_recurrent = InnerRecurrent(inner_input_dim, outer_input_dim,
                                     inner_dim)
    nested_recurrent = OuterLinear(inner_recurrent, inner_dim, **inits)
    nested_recurrent.push_allocation_config()
    nested_recurrent.initialize()

    inner_inputs = tensor.tensor3()
    outer_inputs = tensor.tensor3()
    nested_outs = nested_recurrent.apply(
        outer_inputs=outer_inputs, inner_inputs=inner_inputs)
    func = function(inputs=[inner_inputs, outer_inputs], outputs=nested_outs,
                    allow_input_downcast=True)

    inner_input_val = numpy.random.uniform(
        size=(inner_steps, batch_size, inner_input_dim))
    outer_input_val = numpy.random.uniform(
        size=(outer_steps, batch_size, outer_input_dim))
    # everything works fine up to and including this call
    outputs_val = func(inner_input_val, outer_input_val)
    # the error occurs when the Model is built below
    outs_mean = nested_outs.mean()
    model = Model(outs_mean)


if __name__ == '__main__':
    test_nested_recurrent()
```
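For reference, a small diagnostic sketch of the symptom, to be appended at the end of test_nested_recurrent(). It only uses Model.get_top_bricks and ComputationGraph.scan_variables from Blocks; the expectation of a single top brick is my reading of how it should behave, not verified output:

```python
from blocks.graph import ComputationGraph

# OuterLinear is the only outermost brick, so a single top brick is
# expected; with this bug, nested bricks surface at the top level too
# (or Model construction fails outright).
try:
    print(Model(outs_mean).get_top_bricks())
except Exception as exc:
    print(exc)

# As far as I can tell, Model gathers brick annotations from
# self.variables + self.scan_variables, i.e. it looks only one scan
# level deep; my guess is that variables created inside the nested
# (inner) scan are therefore missed. This prints how many inner-scan
# variables are visible from the outer graph.
print(len(ComputationGraph(outs_mean).scan_variables))
```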