support recurrent with no states.#1113
Conversation
|
I'll let someone more familiar with |
|
Ok, I'll write the test case. |
blocks/bricks/recurrent/base.py
Outdated
| # Ensure that all initial states are available. | ||
| initial_states = brick.initial_states(batch_size, as_dict=True, | ||
| *args, **kwargs) | ||
| for state_name in application.states: |
There was a problem hiding this comment.
It seems like the code starting from this line can be moved out of the if clause, and the else part is not really necessary. Right now we pay a high price of having an additional level of indentation for this new feature, and it would be great to keep the complexity of the code down.
There was a problem hiding this comment.
I suggest to add the line before
else:
initial_states = OrderedDict()|
In the original code, the |
|
You may produce the error with the following code. The error occurs when the class does not contain a recurrent method named import numpy
import theano
from numpy.testing import assert_allclose
from theano import tensor
from blocks.bricks import Brick
from blocks.bricks.recurrent import BaseRecurrent, recurrent
# from recurrent import recurrent
class RecurrentWrapperNoStatesClass(BaseRecurrent):
def __init__(self, dim, **kwargs):
super(RecurrentWrapperNoStatesClass, self).__init__(**kwargs)
self.dim = dim
def get_dim(self, name):
if name in ['inputs', 'outputs', 'outputs_2']:
return self.dim
if name == 'mask':
return 0
return super(RecurrentWrapperNoStatesClass, self).get_dim(name)
@recurrent(sequences=['inputs', 'mask'], states=[],
outputs=['outputs', 'outputs_2'], contexts=[])
def apply2(self, inputs, mask=None):
outputs = inputs * 10
outputs_2 = tensor.sqr(inputs)
if mask:
outputs *= mask
outputs_2 *= mask
return outputs, outputs_2
if __name__ == '__main__':
recurrent_examples = RecurrentWrapperNoStatesClass(
dim=11, name='test_example')
X = tensor.tensor3('X')
out, out_2 = recurrent_examples.apply2(inputs=X, mask=None)
x_val = numpy.random.uniform(size=(5, 1, 1))
x_val = numpy.asarray(x_val, dtype=theano.config.floatX)
out_eval = out.eval({X: x_val})
out_2_eval = out_2.eval({X: x_val})
assert_allclose(x_val * 10, out_eval)
assert_allclose(numpy.square(x_val), out_2_eval) |
blocks/bricks/recurrent/base.py
Outdated
| state_name, brick.name)) | ||
| states_given = dict_subset(kwargs, application.states) | ||
| else: | ||
| states_given = {} |
There was a problem hiding this comment.
If I remember right, it should be an OrderedDict.
There was a problem hiding this comment.
Since states_given in the else clause is never used, it does not matter whether it is a OrderedDict, dict or None.
| @property | ||
| def scan_variables(self): | ||
| """Variables of Scan ops.""" | ||
| return list(chain(*[g.variables for g in self._scan_graphs])) |
There was a problem hiding this comment.
This code supposed that no recurrent class is nested. #1115
blocks/graph/__init__.py
Outdated
| """Variables of Scan ops.""" | ||
| return list(chain(*[g.variables for g in self._scan_graphs])) | ||
| # BFS | ||
| scan_graphs = self._scan_graphs |
There was a problem hiding this comment.
You probably want to copy scan_graphs here, like e.g. scan_graphs = list(self._scan_graphs).
blocks/bricks/recurrent/base.py
Outdated
| """ | ||
| if not hasattr(self, 'apply') or not self.apply.states: | ||
| return | ||
|
|
There was a problem hiding this comment.
Can you explain how it works? I cannot immediately see it.
There was a problem hiding this comment.
when some subclass call the default initial_states function in the BaseRecurrent class. This line would check whether it is necessary to return the initial states. If the subclass does not have an apply method or its apply method does not contain states, the initial_states would not return anything.
This line would make it to support recurrent class with no apply function or with no states.
There was a problem hiding this comment.
Why do you want to have a class without apply? It's a mistake if a user forgot to define apply and the best is to crash soon.
In a case if apply.states is empty, initial_states would return an empty list before this change, why is it wrong?
There was a problem hiding this comment.
If this line is added, the above code, which contains a recurrent brick with no apply method, would run well.
But, you are right about the apply method. The Brick subclass should follow some design rules. The problem is no code checks whether there is an apply method in a Brick subclass at present.
There was a problem hiding this comment.
@Beronx86 , checking apply.states in BaseRecurrent.initial_states is not a solution. There are quite a few places in Blocks-dependent code where initial_states method is overloaded. Instead, like in your previous solution, initial_states should not be called if application does not have states. Can you please revert back to the previous version of your fix?
There was a problem hiding this comment.
@rizar I think this check could be carried out in Brick.__init__ method. So we can make sure all Brick subclasses contain apply methods. I reverted back the changes in BaseRecurrent.
|
I don't understand, now you have removed your fix, and it is again not supported to have no states property. Why not just implemented like you did in the first place, but with more gentle changes to the code as I suggested? |
The recurrent wrapper does not support loop with no states. But this kind of loop may be useful. So I modified the codes.
Fixes #1112