Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions blocks/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Annotated computation graph management."""
import logging
from collections import OrderedDict
from collections import OrderedDict, deque
from itertools import chain
import warnings

Expand Down Expand Up @@ -103,8 +103,15 @@ def auxiliary_variables(self):

@property
def scan_variables(self):
    """Variables of Scan ops, collected recursively.

    Performs a breadth-first traversal over the inner graphs of all
    Scan ops, so variables of nested Scans (a Scan inside a Scan)
    are included as well.

    Returns
    -------
    list
        Variables of every (possibly nested) Scan inner graph, in
        breadth-first order.

    """
    queue = deque(self._scan_graphs)
    variables = []
    while queue:
        graph = queue.popleft()
        variables.extend(graph.variables)
        # Enqueue inner graphs of nested Scan ops.  Extending with an
        # empty sequence is a no-op, so no emptiness check is needed.
        queue.extend(graph._scan_graphs)
    return variables
Copy link
Contributor

@rizar rizar Jun 14, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a desirable change, but I have a few concerns:

  • modifying the list that you iterate over is hard to understand and can very likely cause errors. Please try to simplify the code
  • a test is required


def _get_variables(self):
"""Collect variables, updates and auxiliary variables.
Expand Down
52 changes: 51 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import theano
import warnings
from numpy.testing import assert_allclose
from theano import tensor
from theano import function, tensor
from theano.sandbox.rng_mrg import MRG_RandomStreams

from blocks.bricks import MLP, Identity, Logistic, Tanh
Expand Down Expand Up @@ -74,6 +74,56 @@ def test_computation_graph():
assert all(v in cg6.scan_variables for v in scan.inputs + scan.outputs)


def test_computation_graph_nested_scan():
    """Check that ComputationGraph collects variables of nested Scans."""
    ctx_input = tensor.matrix('inner_x')
    seq_input = tensor.matrix('outer_x')
    scale = tensor.matrix('factor')

    def run_inner(inner_seq, context_step):
        # Inner scan: add the outer context to every inner step.
        summed, _ = theano.scan(fn=lambda inp, ctx: inp + ctx,
                                sequences=inner_seq,
                                non_sequences=context_step)
        return summed.sum(axis=0)

    scanned, _ = theano.scan(fn=lambda inp, ctx: run_inner(ctx, inp),
                             sequences=seq_input,
                             non_sequences=ctx_input)

    scaled = scanned * scale

    # The outer Scan op sits directly below the multiplication node.
    nested_scan = scaled.owner.inputs[0].owner.op
    cg = ComputationGraph(scanned)

    assert cg.scans == [nested_scan]
    assert all(var in cg.scan_variables
               for var in nested_scan.inputs + nested_scan.outputs)

    func = function(inputs=[ctx_input, seq_input, scale], outputs=scaled,
                    allow_input_downcast=True)

    inner_len, outer_len, dim = 9, 7, 3

    floatX = theano.config.floatX
    ctx_val = numpy.asarray(numpy.random.uniform(size=(inner_len, dim)),
                            dtype=floatX)
    seq_val = numpy.asarray(numpy.random.uniform(size=(outer_len, dim)),
                            dtype=floatX)
    scale_val = numpy.asarray(numpy.random.uniform(size=(outer_len, dim)),
                              dtype=floatX)

    actual = func(ctx_val, seq_val, scale_val)

    # Reference computation with plain numpy loops.
    expected = numpy.zeros(shape=(outer_len, dim))
    for i, step in enumerate(seq_val):
        for row in ctx_val:
            expected[i] += row + step
    expected = expected * scale_val

    assert_allclose(actual, expected)


def test_computation_graph_variable_duplicate():
# Test if ComputationGraph.variables contains duplicates if some outputs
# are part of the computation graph
Expand Down
91 changes: 89 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from theano import tensor
from numpy.testing import assert_allclose, assert_raises

from blocks.bricks import MLP, Tanh
from blocks.bricks import MLP, Initializable, Linear, Tanh
from blocks.bricks.parallel import Fork
from blocks.bricks.recurrent import BaseRecurrent, GatedRecurrent, recurrent
from blocks.model import Model
from blocks.graph import add_role, PARAMETER
from blocks.utils import shared_floatx
from blocks.utils import dict_union, shared_floatx


def test_model():
Expand Down Expand Up @@ -70,3 +72,88 @@ def test_model_handles_brickless_parameteres():
y = x.dot(v)
model = Model(y)
assert list(model.get_parameter_dict().items()) == [('V', v)]


class InnerRecurrent(BaseRecurrent, Initializable):
    """A GRU whose inputs are forked from an inner and an outer source.

    Used to build a nested recurrent model: at every step of an outer
    recurrent transition, this brick runs its own recurrent transition
    conditioned on the outer step's input as context.

    Parameters
    ----------
    inner_input_dim : int
        Dimension of the inner sequence input.
    outer_input_dim : int
        Dimension of the context input from the outer recurrence.
    inner_dim : int
        Dimension of the GRU state.

    """
    def __init__(self, inner_input_dim, outer_input_dim, inner_dim, **kwargs):
        self.inner_gru = GatedRecurrent(dim=inner_dim, name='inner_gru')

        # Fork both input streams into the (non-mask) sequence inputs
        # that the GRU transition expects.
        self.inner_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=inner_input_dim, name='inner_input_fork')
        self.outer_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=outer_input_dim, name='inner_outer_fork')

        super(InnerRecurrent, self).__init__(**kwargs)

        self.children = [
            self.inner_gru, self.inner_input_fork, self.outer_input_fork]

    def _push_allocation_config(self):
        # The forks must produce outputs whose dimensions match the
        # corresponding GRU transition inputs.
        self.inner_input_fork.output_dims = self.inner_gru.get_dims(
            self.inner_input_fork.output_names)
        self.outer_input_fork.output_dims = self.inner_gru.get_dims(
            self.outer_input_fork.output_names)

    @recurrent(sequences=['inner_inputs'], states=['states'],
               contexts=['outer_inputs'], outputs=['states'])
    def apply(self, inner_inputs, states, outer_inputs):
        """One inner step: sum both forked inputs and update the GRU."""
        forked_inputs = self.inner_input_fork.apply(inner_inputs, as_dict=True)
        forked_states = self.outer_input_fork.apply(outer_inputs, as_dict=True)

        gru_inputs = {key: forked_inputs[key] + forked_states[key]
                      for key in forked_inputs.keys()}

        new_states = self.inner_gru.apply(
            iterate=False,
            **dict_union(gru_inputs, {'states': states}))
        return new_states

    def get_dim(self, name):
        if name == 'states':
            return self.inner_gru.get_dim(name)
        # Bug fix: the original did `return AttributeError`, handing the
        # exception class back to the caller instead of raising it.
        raise AttributeError(name)


class OuterLinear(BaseRecurrent, Initializable):
    """Outer recurrence that runs an inner recurrent brick at each step.

    At every outer step the inner recurrence is unrolled over its own
    sequence, its states are projected to one dimension, and the
    projections are averaged over the inner time axis.

    """
    def __init__(self, inner_recurrent, inner_dim, **kwargs):
        self.inner_recurrent = inner_recurrent
        self.linear_map = Linear(input_dim=inner_dim, output_dim=1)
        super(OuterLinear, self).__init__(**kwargs)
        self.children = [self.inner_recurrent, self.linear_map]

    @recurrent(sequences=['outer_inputs'], states=[],
               contexts=['inner_inputs'], outputs=['weighted_averages'])
    def apply(self, outer_inputs, inner_inputs):
        """One outer step: run the inner recurrence, project, average."""
        states = self.inner_recurrent.apply(
            inner_inputs=inner_inputs, outer_inputs=outer_inputs)
        projected = self.linear_map.apply(states)
        return projected.mean(axis=0)


def test_nested_recurrent_model():
    """A Model built on nested recurrent bricks exposes one top brick."""
    inner_input_dim, outer_input_dim, inner_dim = 11, 17, 5

    inner = InnerRecurrent(inner_input_dim, outer_input_dim, inner_dim)
    nested = OuterLinear(inner, inner_dim)

    inner_seq = tensor.tensor3()
    outer_seq = tensor.tensor3()

    outputs = nested.apply(outer_inputs=outer_seq, inner_inputs=inner_seq)
    model = Model(outputs.mean())

    assert len(model.top_bricks) == 1