diff --git a/blocks/graph/__init__.py b/blocks/graph/__init__.py
index e40ced0a..0e53cdac 100644
--- a/blocks/graph/__init__.py
+++ b/blocks/graph/__init__.py
@@ -1,6 +1,6 @@
 """Annotated computation graph management."""
 import logging
-from collections import OrderedDict
+from collections import OrderedDict, deque
 from itertools import chain
 import warnings
 
@@ -103,8 +103,15 @@ def auxiliary_variables(self):
 
     @property
    def scan_variables(self):
-        """Variables of Scan ops."""
-        return list(chain(*[g.variables for g in self._scan_graphs]))
+        """Variables of Scan ops, including nested scans (breadth-first)."""
+        sg_que = deque(self._scan_graphs)
+        var_list = []
+        while sg_que:
+            g = sg_que.popleft()
+            var_list.append(g.variables)
+            if g._scan_graphs:
+                sg_que.extend(g._scan_graphs)
+        return list(chain(*var_list))
 
     def _get_variables(self):
         """Collect variables, updates and auxiliary variables.
diff --git a/tests/test_graph.py b/tests/test_graph.py
index dda332f4..7cad09d5 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -2,7 +2,7 @@
 import theano
 import warnings
 from numpy.testing import assert_allclose
-from theano import tensor
+from theano import function, tensor
 from theano.sandbox.rng_mrg import MRG_RandomStreams
 
 from blocks.bricks import MLP, Identity, Logistic, Tanh
@@ -74,6 +74,56 @@ def test_computation_graph():
     assert all(v in cg6.scan_variables for v in scan.inputs + scan.outputs)
 
 
+def test_computation_graph_nested_scan():
+    inner_x = tensor.matrix('inner_x')
+    outer_x = tensor.matrix('outer_x')
+    factor = tensor.matrix('factor')
+
+    def inner_scan(inner_x_, outer_x_one_step):
+        inner_o, _ = theano.scan(fn=lambda inp, ctx: inp + ctx,
+                                 sequences=inner_x_,
+                                 non_sequences=outer_x_one_step)
+        return inner_o.sum(axis=0)
+
+    outer_o, _ = theano.scan(fn=lambda inp, ctx: inner_scan(ctx, inp),
+                             sequences=outer_x,
+                             non_sequences=inner_x)
+
+    outs = outer_o * factor
+
+    nested_scan = outs.owner.inputs[0].owner.op
+    cg = ComputationGraph(outer_o)
+
+    assert cg.scans == [nested_scan]
+    assert all(var in cg.scan_variables
+               for var in nested_scan.inputs + nested_scan.outputs)
+
+    func = function(inputs=[inner_x, outer_x, factor], outputs=outs,
+                    allow_input_downcast=True)
+
+    in_len = 9
+    out_len = 7
+    dim = 3
+
+    floatX = theano.config.floatX
+    x_val = numpy.asarray(numpy.random.uniform(size=(in_len, dim)),
+                          dtype=floatX)
+    y_val = numpy.asarray(numpy.random.uniform(size=(out_len, dim)),
+                          dtype=floatX)
+    factor_val = numpy.asarray(numpy.random.uniform(size=(out_len, dim)),
+                               dtype=floatX)
+
+    results = func(x_val, y_val, factor_val)
+
+    results2 = numpy.zeros(shape=(out_len, dim))
+    for i, y in enumerate(y_val):
+        for x in x_val:
+            results2[i] += (x + y)
+    results2 = results2 * factor_val
+
+    assert_allclose(results, results2)
+
+
 def test_computation_graph_variable_duplicate():
     # Test if ComputationGraph.variables contains duplicates if some outputs
     # are part of the computation graph
diff --git a/tests/test_model.py b/tests/test_model.py
index f65f05ca..a5480936 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -3,10 +3,12 @@
 from theano import tensor
 from numpy.testing import assert_allclose, assert_raises
 
-from blocks.bricks import MLP, Tanh
+from blocks.bricks import MLP, Initializable, Linear, Tanh
+from blocks.bricks.parallel import Fork
+from blocks.bricks.recurrent import BaseRecurrent, GatedRecurrent, recurrent
 from blocks.model import Model
 from blocks.graph import add_role, PARAMETER
-from blocks.utils import shared_floatx
+from blocks.utils import dict_union, shared_floatx
 
 
 def test_model():
@@ -70,3 +72,88 @@ def test_model_handles_brickless_parameteres():
     y = x.dot(v)
     model = Model(y)
     assert list(model.get_parameter_dict().items()) == [('V', v)]
+
+
+class InnerRecurrent(BaseRecurrent, Initializable):
+    def __init__(self, inner_input_dim, outer_input_dim, inner_dim, **kwargs):
+        self.inner_gru = GatedRecurrent(dim=inner_dim, name='inner_gru')
+
+        self.inner_input_fork = Fork(
+            output_names=[name for name in self.inner_gru.apply.sequences
+                          if 'mask' not in name],
+            input_dim=inner_input_dim, name='inner_input_fork')
+        self.outer_input_fork = Fork(
+            output_names=[name for name in self.inner_gru.apply.sequences
+                          if 'mask' not in name],
+            input_dim=outer_input_dim, name='inner_outer_fork')
+
+        super(InnerRecurrent, self).__init__(**kwargs)
+
+        self.children = [
+            self.inner_gru, self.inner_input_fork, self.outer_input_fork]
+
+    def _push_allocation_config(self):
+        self.inner_input_fork.output_dims = self.inner_gru.get_dims(
+            self.inner_input_fork.output_names)
+        self.outer_input_fork.output_dims = self.inner_gru.get_dims(
+            self.outer_input_fork.output_names)
+
+    @recurrent(sequences=['inner_inputs'], states=['states'],
+               contexts=['outer_inputs'], outputs=['states'])
+    def apply(self, inner_inputs, states, outer_inputs):
+        forked_inputs = self.inner_input_fork.apply(inner_inputs, as_dict=True)
+        forked_states = self.outer_input_fork.apply(outer_inputs, as_dict=True)
+
+        gru_inputs = {key: forked_inputs[key] + forked_states[key]
+                      for key in forked_inputs.keys()}
+
+        new_states = self.inner_gru.apply(
+            iterate=False,
+            **dict_union(gru_inputs, {'states': states}))
+        return new_states  # mean according to the time axis
+
+    def get_dim(self, name):
+        if name == 'states':
+            return self.inner_gru.get_dim(name)
+        else:
+            raise AttributeError(name)
+
+
+class OuterLinear(BaseRecurrent, Initializable):
+    def __init__(self, inner_recurrent, inner_dim, **kwargs):
+        self.inner_recurrent = inner_recurrent
+        self.linear_map = Linear(input_dim=inner_dim, output_dim=1)
+
+        super(OuterLinear, self).__init__(**kwargs)
+
+        self.children = [self.inner_recurrent, self.linear_map]
+
+    @recurrent(sequences=['outer_inputs'], states=[],
+               contexts=['inner_inputs'], outputs=['weighted_averages'])
+    def apply(self, outer_inputs, inner_inputs):
+        inner_states = self.inner_recurrent.apply(
+            inner_inputs=inner_inputs, outer_inputs=outer_inputs)
+        linear_outs = self.linear_map.apply(inner_states)
+        return linear_outs.mean(axis=0)
+
+
+def test_nested_recurrent_model():
+    inner_input_dim = 11
+    outer_input_dim = 17
+    inner_dim = 5
+
+    inner_recurrent = InnerRecurrent(inner_input_dim, outer_input_dim,
+                                     inner_dim)
+    nested_recurrent = OuterLinear(inner_recurrent, inner_dim)
+
+    inner_inputs = tensor.tensor3()
+    outer_inputs = tensor.tensor3()
+
+    nested_outs = nested_recurrent.apply(
+        outer_inputs=outer_inputs, inner_inputs=inner_inputs)
+
+    outs_mean = nested_outs.mean()
+    model = Model(outs_mean)
+
+    assert len(model.top_bricks) == 1
+