
Commit fa72e54

reyoung authored and Yang Yang(Tony) committed
Python API for StaticRNN (#4991)
1 parent 23bf6b2 commit fa72e54

4 files changed: 226 additions, 10 deletions

python/paddle/v2/framework/framework.py

Lines changed: 4 additions & 0 deletions

@@ -113,6 +113,10 @@ def data_type(self):
     def lod_level(self):
         return self.desc.lod_level()
 
+    @property
+    def type(self):
+        return self.desc.type()
+
     @staticmethod
     def _unique_var_name_():
         uid = core.unique_integer()  # unique during whole process.
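
The new `type` property simply forwards to the variable's underlying `VarDesc`. A minimal sketch of reading it, assuming a Variable produced by the existing `layers.data` helper (the `img` name here is illustrative, not part of this commit):

    from paddle.v2.framework.layers import data

    # `data` returns a Variable; the new property exposes the VarDesc type,
    # which StaticRNN.step_input below passes along (type=x.type) when it
    # recreates the input variable inside the step block.
    img = data(shape=[80, 22, 22], data_type='float32', name='image')
    print img.type  # delegates to self.desc.type()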

python/paddle/v2/framework/layer_helper.py

Lines changed: 8 additions & 2 deletions

@@ -1,8 +1,11 @@
-from paddle.v2.framework.framework import Variable, OpProtoHolder, g_program, g_init_program
-import paddle.v2.framework.core as core
 import copy
 import itertools
 
+import paddle.v2.framework.core as core
+
+from paddle.v2.framework.framework import Variable, g_program, \
+    g_init_program
+
 
 def unique_name(prefix):
     uid = core.unique_integer()  # unique during whole process.

@@ -130,6 +133,9 @@ def create_tmp_variable(self, dtype):
             dtype=dtype,
             persistable=False)
 
+    def create_variable(self, *args, **kwargs):
+        return self.program.current_block().create_var(*args, **kwargs)
+
     def create_global_variable(self, *args, **kwargs):
         return self.program.global_block().create_var(
             *args, persistable=False, **kwargs)
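
The new `create_variable` targets the block currently under construction, in contrast to `create_global_variable`, which always writes into the program's global block. A minimal sketch under the assumption that the helper is constructed the same way `StaticRNN.__init__` constructs one:

    from paddle.v2.framework.layer_helper import LayerHelper, unique_name

    helper = LayerHelper("static_rnn", name=None, program=None)

    # Created in whatever block is current; inside an RNN step block this is
    # how the per-step `pre_mem` variable gets created:
    step_var = helper.create_variable(
        name=unique_name("static_rnn@mem"), dtype='float32', shape=[-1, 32])

    # Always created in the global (root) block, with persistable=False:
    g_var = helper.create_global_variable(dtype='float32', shape=[-1, 32])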

python/paddle/v2/framework/layers.py

Lines changed: 176 additions & 8 deletions

@@ -1,10 +1,11 @@
-from paddle.v2.framework.layer_helper import LayerHelper
+from paddle.v2.framework.layer_helper import LayerHelper, unique_name
 import paddle.v2.framework.core as core
-from paddle.v2.framework.framework import OpProtoHolder, Variable
+from paddle.v2.framework.framework import OpProtoHolder, Variable, Program
 import re
 
 __all__ = [
-    'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat'
+    'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
+    'StaticRNN'
 ]

@@ -26,7 +27,9 @@ def fc(input,
     mul_results = []
     for input_var, param_attr in helper.iter_inputs_and_params():
         input_shape = input_var.shape
-        param_shape = list(input_shape[num_flatten_dims:]) + [size]
+        param_shape = [
+            reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
+        ] + [size]
 
         w = helper.create_parameter(
             attr=param_attr, shape=param_shape, dtype=dtype)

@@ -38,10 +41,8 @@ def fc(input,
                 "Y": w,
             },
             outputs={"Out": tmp},
-            attrs={
-                'x_num_col_dims': num_flatten_dims,
-                'y_num_col_dims': len(input_shape) - num_flatten_dims
-            })
+            attrs={'x_num_col_dims': num_flatten_dims,
+                   'y_num_col_dims': 1})
         mul_results.append(tmp)
 
     # sum

@@ -273,3 +274,170 @@ def pool2d(input,
         })
 
     return pool_out
+
+
+class BlockGuard(object):
+    """
+    BlockGuard is used to create a sub-block in a program by using the
+    Python `with` keyword.
+    """
+
+    def __init__(self, program):
+        if not isinstance(program, Program):
+            raise TypeError("BlockGuard takes a program")
+        self.program = program
+
+    def __enter__(self):
+        self.program.create_block()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.program.rollback()
+        if exc_type is not None:
+            return False  # re-raise the exception
+        return True
+
+
+class StaticRNNGuard(BlockGuard):
+    def __init__(self, rnn):
+        if not isinstance(rnn, StaticRNN):
+            raise TypeError("StaticRNNGuard takes a StaticRNN")
+        super(StaticRNNGuard, self).__init__(rnn.helper.program)
+        self.rnn = rnn
+
+    def __enter__(self):
+        self.rnn.status = StaticRNN.IN_RNN_BLOCK
+        return super(StaticRNNGuard, self).__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
+        self.rnn.complete_rnn_op()
+        return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
+
+
+class StaticRNNMemoryLink(object):
+    """
+    :param init: the initial variable for the memory
+    :type init: Variable
+    :param pre_mem: the memory variable in the previous time step
+    :type pre_mem: Variable
+    :param mem: the memory variable in the current time step
+    :type mem: Variable
+    """
+
+    def __init__(self, init, pre_mem, mem=None):
+        self.init = init
+        self.pre_mem = pre_mem
+        self.mem = mem
+
+
+class StaticRNN(object):
+    BEFORE_RNN_BLOCK = 0
+    IN_RNN_BLOCK = 1
+    AFTER_RNN_BLOCK = 2
+
+    def __init__(self, name=None, program=None):
+        self.helper = LayerHelper("static_rnn", name=name, program=program)
+        self.memories = {}  # memory map: from pre_mem.name --> MemoryLink
+        self.inputs = []  # input variable list in current block
+        self.outputs = []  # output variable list in parent block
+        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag
+        # Sequence length; since this is a static RNN, the sequence length
+        # is fixed.
+        self.seq_len = None
+
+    def step(self):
+        return StaticRNNGuard(self)
+
+    def _assert_in_rnn_block_(self, method):
+        if self.status != StaticRNN.IN_RNN_BLOCK:
+            raise ValueError("You must invoke {0} in rnn block".format(method))
+
+    def memory(self, init=None, shape=None, dtype=None, init_value=0):
+        self._assert_in_rnn_block_('memory')
+        if init is None:
+            if shape is None or dtype is None:
+                raise ValueError(
+                    "if init is None, memory at least needs shape and dtype")
+            parent_block = self.parent_block()
+            var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
+            boot_var = parent_block.create_var(
+                name=var_name, shape=shape, dtype=dtype, persistable=False)
+
+            parent_block.append_op(
+                type="fill_constant",
+                inputs={},
+                outputs={'Out': [boot_var]},
+                attrs={
+                    'value': init_value,
+                    'shape': boot_var.shape,
+                    'data_type': boot_var.data_type
+                })
+
+            return self.memory(init=boot_var)
+        else:
+            pre_mem = self.helper.create_variable(
+                name=unique_name("@".join([self.helper.name, "mem"])),
+                dtype=init.data_type,
+                shape=init.shape)
+            self.memories[pre_mem.name] = StaticRNNMemoryLink(
+                init=init, pre_mem=pre_mem)
+            return pre_mem
+
+    def step_input(self, x):
+        self._assert_in_rnn_block_('step_input')
+        if not isinstance(x, Variable):
+            raise TypeError("step input takes a Variable")
+        if self.seq_len is None:
+            self.seq_len = x.shape[1]
+        elif self.seq_len != x.shape[1]:
+            raise ValueError("Static RNN only takes fixed seq_len input")
+
+        ipt = self.helper.create_variable(
+            name=x.name,
+            dtype=x.data_type,
+            shape=[-1] + list(x.shape[2:]),
+            type=x.type)
+        self.inputs.append(ipt)
+        return ipt
+
+    def step_output(self, o):
+        self._assert_in_rnn_block_('step_output')
+        if not isinstance(o, Variable):
+            raise TypeError("step output takes a Variable")
+
+        out_var = self.parent_block().create_var(
+            name=o.name,
+            shape=[-1, self.seq_len] + list(o.shape[1:]),
+            dtype=o.data_type)
+
+        self.outputs.append(out_var)
+
+    def output(self, *outputs):
+        for each in outputs:
+            self.step_output(each)
+
+    def update_memory(self, mem, var):
+        if not isinstance(mem, Variable) or not isinstance(var, Variable):
+            raise TypeError("update memory should take variables")
+        self.memories[mem.name].mem = var
+
+    def parent_block(self):
+        prog = self.helper.program
+        parent_idx = prog.current_block().parent_idx
+        assert parent_idx >= 0
+        parent_block = prog.block(parent_idx)
+        return parent_block
+
+    def __call__(self, *args, **kwargs):
+        if self.status != StaticRNN.AFTER_RNN_BLOCK:
+            raise ValueError("RNN output can only be retrieved after rnn block")
+        if len(self.outputs) == 0:
+            raise ValueError("RNN has no output")
+        elif len(self.outputs) == 1:
+            return self.outputs[0]
+        else:
+            return self.outputs
+
+    def complete_rnn_op(self):
+        # TODO(yuyang18): Create the RNN op here.
+        # Implement this method after the RNN op is complete.
+        pass
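
Two things above are worth connecting. First, the `fc` rewrite flattens all trailing dimensions into a single column dimension: with `input_shape = (-1, 80, 100)` and `num_flatten_dims=2`, `param_shape` becomes `[100, size]` and `y_num_col_dims` is pinned to 1, which is what lets the test below call `fc(..., num_flatten_dims=2)` on 3-D input. Second, the StaticRNN methods compose into a small step-scoped API; a condensed usage sketch, assuming `x` is a Variable of shape `(batch, seq_len, dim)` produced by earlier layers:

    rnn = StaticRNN()
    with rnn.step():                  # StaticRNNGuard opens a sub-block
        word = rnn.step_input(x)      # per-step view, shape (-1, dim)
        prev = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)
        out = fc(input=[word, prev], size=32, act='sigmoid')
        rnn.update_memory(prev, out)  # link prev -> out across time steps
        rnn.output(out)               # gather per-step outputs over seq_len
    result = rnn()                    # single output: shape (-1, seq_len, 32)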
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import unittest
+from paddle.v2.framework.layers import *
+from paddle.v2.framework.framework import g_program
+
+
+class TestRNN(unittest.TestCase):
+    def test_rnn(self):
+        img = data(
+            shape=[
+                80,  # sequence length
+                22,  # image height
+                22  # image width
+            ],
+            data_type='float32',
+            name='image')
+        hidden = fc(input=img, size=100, act='sigmoid', num_flatten_dims=2)
+        self.assertEqual((-1, 80, 100), hidden.shape)
+        hidden = fc(input=hidden, size=100, act='sigmoid', num_flatten_dims=2)
+        self.assertEqual((-1, 80, 100), hidden.shape)
+
+        rnn = StaticRNN()
+        with rnn.step():
+            hidden = rnn.step_input(hidden)
+            self.assertEqual((-1, 100), hidden.shape)
+            memory = rnn.memory(shape=(-1, 32), dtype='float32', init_value=0.0)
+
+            rnn_out = fc(input=[hidden, memory], size=32, act='sigmoid')
+            self.assertEqual((-1, 32), rnn_out.shape)
+            rnn.update_memory(memory, rnn_out)
+            rnn.output(rnn_out)
+
+        out = rnn()
+        self.assertEqual((-1, 80, 32), out.shape)
+        print g_program
+
+
+if __name__ == '__main__':
+    unittest.main()
