6
6
import re
7
7
import cStringIO
8
8
from param_attr import ParamAttr
9
+ import contextlib
9
10
10
11
__all__ = [
11
12
'fc' , 'data' , 'cross_entropy' , 'conv2d' , 'pool2d' , 'embedding' , 'concat' ,
@@ -1395,7 +1396,7 @@ def lod_tensor_to_array(x, table, main_program=None):
1395
1396
return array
1396
1397
1397
1398
1398
- def array_to_lod_tensor (x , table , main_program = None ):
1399
+ def array_to_lod_tensor (x , table , main_program = None , startup_program = None ):
1399
1400
"""
1400
1401
This function creates an operator to convert an array to a
1401
1402
LOD_Tensor.
@@ -1476,7 +1477,11 @@ def zeros(shape, dtype, main_program=None):
1476
1477
return fill_constant (value = 0.0 , ** locals ())
1477
1478
1478
1479
1479
- def increment (x , value = 1.0 , in_place = True , main_program = None ):
1480
+ def increment (x ,
1481
+ value = 1.0 ,
1482
+ in_place = True ,
1483
+ main_program = None ,
1484
+ startup_program = None ):
1480
1485
"""
1481
1486
This function creates an operator to increment each value in the input
1482
1487
`x` by an amount: `value` as mentioned in the input parameter. This
@@ -1495,7 +1500,7 @@ def increment(x, value=1.0, in_place=True, main_program=None):
1495
1500
return out
1496
1501
1497
1502
1498
- def array_write (x , i , array = None , main_program = None ):
1503
+ def array_write (x , i , array = None , main_program = None , startup_program = None ):
1499
1504
"""
1500
1505
This function creates an operator to write the data out as a
1501
1506
LOD_TENSOR_ARRAY.
@@ -1534,7 +1539,7 @@ def less_than(x, y, cond=None, main_program=None, **ignored):
1534
1539
return cond
1535
1540
1536
1541
1537
- def array_read (array , i , main_program = None ):
1542
+ def array_read (array , i , main_program = None , startup_program = None ):
1538
1543
"""
1539
1544
This function creates an operator to read the data in as a
1540
1545
LOD_TENSOR_ARRAY.
@@ -1553,7 +1558,7 @@ def array_read(array, i, main_program=None):
1553
1558
return out
1554
1559
1555
1560
1556
- def shrink_memory (x , i , table , main_program = None ):
1561
+ def shrink_memory (x , i , table , main_program = None , startup_program = None ):
1557
1562
"""
1558
1563
This function creates an operator to shrink_rnn_memory using the RankTable
1559
1564
as mentioned in the input parameter.
@@ -1890,3 +1895,209 @@ def __call__(self):
1890
1895
main_program = self .helper .main_program ,
1891
1896
startup_program = self .helper .startup_program ))
1892
1897
return rlist
1898
+
1899
+
1900
class DynamicRNN(object):
    """Build a dynamic-length RNN as a ``While`` loop over time steps.

    The builder moves through three phases, tracked by ``self.status``:

      * ``BEFORE_RNN`` -- construction; ``block()`` not yet entered.
      * ``IN_RNN``     -- inside ``with rnn.block():``; only here may
                          ``step_input``/``memory``/``update_memory``/``output``
                          be called.
      * ``AFTER_RNN``  -- the block has exited; ``rnn()`` returns outputs.

    ``step_input`` converts each sequence input (a LoDTensor) into a
    LoDTensorArray ordered by a shared LoDRankTable; the while loop advances
    ``step_idx`` until it reaches the longest sequence length, and arrays
    collected via ``output()`` are converted back to LoDTensors afterwards.
    """
    BEFORE_RNN = 0  # before block() is entered
    IN_RNN = 1      # inside the with-block
    AFTER_RNN = 2   # block finished; outputs retrievable

    def __init__(self, name=None, main_program=None, startup_program=None):
        self.helper = LayerHelper(
            'dynamic_rnn',
            name=name,
            main_program=main_program,
            startup_program=startup_program)
        self.status = DynamicRNN.BEFORE_RNN
        self.lod_rank_table = None
        self.max_seq_len = None
        self.step_idx = None
        # Constant index 0, used to address the first element of tensor arrays.
        self.zero_idx = fill_constant(shape=[1], value=0, dtype='int64')
        self.mem_dict = dict()   # memory variable name -> its backing array
        self.output_array = []   # LoDTensorArrays collected by output()
        self.outputs = []        # LoDTensors produced after the rnn block ends
        self.cond = self.helper.create_tmp_variable(dtype='bool')
        self.cond.stop_gradient = False
        self.while_op = While(self.cond)
        self.input_array = []    # (LoDTensorArray, dtype) pairs from step_input()
        self.mem_link = []       # (new_mem, mem_array) pairs from update_memory()

    def step_input(self, x):
        """Declare ``x`` as a sequence input and return its current time step.

        Must be called inside ``block()``. The first call also creates the
        shared LoDRankTable, the ``max_seq_len`` variable and the loop
        condition ``step_idx < max_seq_len`` in the parent block.

        Raises:
            TypeError: if ``x`` is not a ``Variable``.
        """
        self._assert_in_rnn_block_("step_input")
        if not isinstance(x, Variable):
            raise TypeError(
                "step_input() can only take a Variable as its input")
        parent_block = self._parent_block_()
        if self.lod_rank_table is None:
            # First sequence input: set up the rank table, max_seq_len and
            # the while-loop condition, all in the enclosing (parent) block.
            self.lod_rank_table = parent_block.create_var(
                name=unique_name('lod_rank_table'),
                type=core.VarDesc.VarType.LOD_RANK_TABLE)
            self.lod_rank_table.stop_gradient = True
            parent_block.append_op(
                type='lod_rank_table',
                inputs={"X": x},
                outputs={"Out": self.lod_rank_table})
            self.max_seq_len = parent_block.create_var(
                name=unique_name('dynamic_rnn_max_seq_len'), dtype='int64')
            self.max_seq_len.stop_gradient = False
            parent_block.append_op(
                type='max_sequence_len',
                inputs={'RankTable': self.lod_rank_table},
                outputs={"Out": self.max_seq_len})
            self.cond.stop_gradient = True
            parent_block.append_op(
                type='less_than',
                inputs={'X': self.step_idx,
                        'Y': self.max_seq_len},
                outputs={'Out': self.cond})

        input_array = parent_block.create_var(
            name=unique_name('dynamic_rnn_input_array'),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=x.dtype)
        self.input_array.append((input_array, x.dtype))
        parent_block.append_op(
            type='lod_tensor_to_array',
            inputs={'X': x,
                    'RankTable': self.lod_rank_table},
            outputs={'Out': input_array})
        return array_read(
            array=input_array, i=self.step_idx, **self.helper.to_kwargs)

    @contextlib.contextmanager
    def block(self):
        """Context manager delimiting the per-time-step body of the RNN.

        On exit it appends the loop bookkeeping ops (increment ``step_idx``,
        write back updated memories, re-evaluate the condition) and converts
        every collected output array back into a LoDTensor.

        Raises:
            ValueError: if called more than once on the same instance.
        """
        if self.status != DynamicRNN.BEFORE_RNN:
            raise ValueError("rnn.block() can only be invoke once")
        self.step_idx = fill_constant(shape=[1], dtype='int64', value=0)
        self.step_idx.stop_gradient = False
        self.status = DynamicRNN.IN_RNN
        with self.while_op.block():
            yield
            increment(
                x=self.step_idx,
                value=1.0,
                in_place=True,
                **self.helper.to_kwargs)

            for new_mem, mem_array in self.mem_link:
                array_write(
                    x=new_mem,
                    i=self.step_idx,
                    array=mem_array,
                    **self.helper.to_kwargs)

            less_than(
                x=self.step_idx,
                y=self.max_seq_len,
                cond=self.cond,
                **self.helper.to_kwargs)

        self.status = DynamicRNN.AFTER_RNN
        for each_array in self.output_array:
            self.outputs.append(
                array_to_lod_tensor(
                    x=each_array,
                    table=self.lod_rank_table,
                    **self.helper.to_kwargs))

    def __call__(self, *args, **kwargs):
        """Return the RNN outputs: a single Variable, or a list if several.

        Raises:
            ValueError: if invoked before the rnn block has finished.
        """
        if self.status != DynamicRNN.AFTER_RNN:
            raise ValueError(
                "Dynamic RNN outputs can only be retrieved after rnn block")
        if len(self.outputs) == 1:
            return self.outputs[0]
        else:
            return self.outputs

    def memory(self, init=None, shape=None, value=0.0, dtype='float32'):
        """Create a recurrent memory, returning its value at the current step.

        Either pass an ``init`` Variable, or pass ``shape``/``value`` to have
        the initial state filled like the first step input's batch (requires
        a prior ``step_input`` call).

        Raises:
            TypeError: if ``init`` is given but is not a Variable.
            ValueError: if neither ``init`` nor a prior ``step_input`` exists.
        """
        self._assert_in_rnn_block_('memory')
        if init is not None:
            if not isinstance(init, Variable):
                raise TypeError(
                    "The input arg `init` of memory() must be a Variable")
            parent_block = self._parent_block_()
            mem_array = parent_block.create_var(
                name=unique_name('dynamic_rnn_mem_array'),
                type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
                dtype=init.dtype)
            parent_block.append_op(
                type='write_to_array',
                inputs={'X': init,
                        'I': self.zero_idx},
                outputs={'Out': mem_array})
            retv = array_read(
                array=mem_array, i=self.step_idx, **self.helper.to_kwargs)
            # Shrink the memory to the number of sequences still active at
            # this time step (sequences are sorted by length in the table).
            retv = shrink_memory(
                x=retv,
                i=self.step_idx,
                table=self.lod_rank_table,
                **self.helper.to_kwargs)
            self.mem_dict[retv.name] = mem_array
            return retv
        else:
            if len(self.input_array) == 0:
                raise ValueError(
                    "step_input should be invoked before memory(shape=..., value=...)"
                )
            parent_block = self._parent_block_()
            init = parent_block.create_var(
                name=unique_name('mem_init'), dtype=dtype)
            # NOTE(review): this rebinds `dtype` to the first step-input's
            # dtype, shadowing the `dtype` parameter used to create `init`
            # above — presumably intentional so `in0` matches the input
            # array; confirm the two dtypes are meant to agree.
            arr, dtype = self.input_array[0]
            in0 = parent_block.create_var(name=unique_name('in0'), dtype=dtype)
            parent_block.append_op(
                type='read_from_array',
                inputs={'X': [arr],
                        'I': [self.zero_idx]},
                outputs={'Out': [in0]})
            parent_block.append_op(
                type='fill_constant_batch_size_like',
                inputs={'Input': [in0]},
                outputs={'Out': [init]},
                attrs={
                    'shape': [-1] + shape,
                    'value': float(value),
                    'dtype': init.dtype
                })
            return self.memory(init=init)

    def update_memory(self, ex_mem, new_mem):
        """Schedule ``new_mem`` to be written back as next step's ``ex_mem``.

        The actual ``array_write`` is deferred until ``block()`` exits.

        Raises:
            TypeError: if either argument is not a Variable.
            ValueError: if ``ex_mem`` was not created by memory(), or
                step_input() has not been called yet.
        """
        self._assert_in_rnn_block_('update_memory')
        if not isinstance(ex_mem, Variable):
            raise TypeError("The input arg `ex_mem` of update_memory() must "
                            "be a Variable")
        if not isinstance(new_mem, Variable):
            raise TypeError("The input arg `new_mem` of update_memory() must "
                            "be a Variable")

        mem_array = self.mem_dict.get(ex_mem.name, None)
        if mem_array is None:
            raise ValueError("Please invoke memory before update_memory")
        if self.lod_rank_table is None:
            raise ValueError("Please invoke step_input before update_memory")

        self.mem_link.append((new_mem, mem_array))

    def output(self, *outputs):
        """Mark each given Variable as a per-step output of the RNN."""
        self._assert_in_rnn_block_('output')
        parent_block = self._parent_block_()
        for each in outputs:
            outside_array = parent_block.create_var(
                name=unique_name("_".join(
                    [self.helper.name, "output_array", each.name])),
                type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
                dtype=each.dtype)
            # Pass the helper's program kwargs, consistent with every other
            # layer call in this class; otherwise the write op would attach
            # to the default program even when a custom main_program is used.
            array_write(
                x=each,
                i=self.step_idx,
                array=outside_array,
                **self.helper.to_kwargs)
            self.output_array.append(outside_array)

    def _parent_block_(self):
        """Return the block enclosing the current (while-body) block."""
        prog = self.helper.main_program
        parent_idx = prog.current_block().parent_idx
        assert parent_idx >= 0
        parent_block = prog.block(parent_idx)

        return parent_block

    def _assert_in_rnn_block_(self, method):
        """Raise ValueError unless `method` is being called inside block()."""
        if self.status != DynamicRNN.IN_RNN:
            raise ValueError("{0} can only be invoked inside rnn block.".format(
                method))
0 commit comments