@@ -28,6 +28,14 @@ def relu(x):
     return np.maximum(x, 0)


+ACTVATION = {
+    'identity': identity,
+    'sigmoid': sigmoid,
+    'tanh': tanh,
+    'relu': relu
+}
+
+
 def lstm(
         input,  # T x 4D
         lod,  # 1 x N
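The lod argument describes how several variable-length sequences are packed row-wise into the single T x 4D input. A minimal sketch of how these offsets are read, using the same [[0, 2, 6, 9]] value the test sets up below (the derived numbers are only for illustration):

    # Assumed reading of the LoD offsets used later in the test.
    lod = [[0, 2, 6, 9]]
    offset = lod[0]
    T = offset[-1]        # 9 packed time steps in total
    N = len(offset) - 1   # 3 sequences in the batch
    lengths = [offset[i + 1] - offset[i] for i in range(N)]  # [2, 4, 3]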
@@ -37,37 +45,45 @@ def lstm(
         w_b=None,  # 1 x 4D
         w_c=None,  # 1 x 3D
         is_reverse=False,
-        gate_act=None,
-        cell_act=None,
-        cand_act=None):
-    def _step(x, w_h, w_c, h_pre, c_pre, gate_act, cell_act, cand_act):
+        act_gate=None,
+        act_cell=None,
+        act_cand=None):
+    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
         g = np.dot(h_pre, w_h)  # 1 x 4D
         g = g + x
         g = np.reshape(g, (1, g.size))
         c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
         if w_c is None:
-            g_i = gate_act(g_i)  # 1 x D
-            g_f = gate_act(g_f)  # 1 x D
+            g_i = act_gate(g_i)  # 1 x D
+            g_f = act_gate(g_f)  # 1 x D
         else:
             w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
-            g_i = gate_act(g_i + w_ic * c_pre)  # 1 x D
-            g_f = gate_act(g_f + w_fc * c_pre)  # 1 x D
-        c = g_f * c_pre + g_i * cand_act(c_tmp)  # 1 x D
+            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
+            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
+        c = g_f * c_pre + g_i * act_cand(c_tmp)  # 1 x D

         if w_c is None:
-            g_o = gate_act(g_o)  # 1 x D
+            g_o = act_gate(g_o)  # 1 x D
         else:
             _, _, w_oc = np.split(w_c, 3, axis=1)
-            g_o = gate_act(g_o + w_oc * c)  # 1 x D
-        h = g_o * cell_act(c)
-        bg = np.concatenate((cand_act(c_tmp), g_i, g_f, g_o), axis=1)
+            g_o = act_gate(g_o + w_oc * c)  # 1 x D
+        h = g_o * act_cell(c)
+        bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
         return h, c, bg

+    def _reverse(x, lod):
+        y = np.zeros_like(x)
+        for i in range(len(lod) - 1):
+            b, e = lod[i], lod[i + 1]
+            y[b:e, :] = np.flip(x[b:e, :], 0)
+        return y
+
     offset = lod[0]
     batch_size = len(offset) - 1
     hidden = []
     cell = []
     gate = []
+    input = _reverse(input, offset) if is_reverse else input
     if w_b is not None:
         input = input + np.tile(w_b, (offset[-1], 1))
     for i in range(batch_size):
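The _reverse helper flips rows only inside each sequence's [begin, end) slice, so the sequences in the packed batch keep their order and never mix. A small standalone sketch of the same loop, with illustrative offsets:

    import numpy as np

    def _reverse(x, lod):
        # Mirrors the helper above: flip each [lod[i], lod[i + 1]) row slice.
        y = np.zeros_like(x)
        for i in range(len(lod) - 1):
            b, e = lod[i], lod[i + 1]
            y[b:e, :] = np.flip(x[b:e, :], 0)
        return y

    x = np.arange(5).reshape(5, 1)         # two packed sequences: rows 0-1 and 2-4
    print(_reverse(x, [0, 2, 5]).ravel())  # [1 0 4 3 2]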
@@ -78,47 +94,62 @@ def _step(x, w_h, w_c, h_pre, c_pre, gate_act, cell_act, cand_act):
         c_pre = c0[i]  # 1 x D
         for j in range(seq_len):
             # compute one step
-            h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, gate_act,
-                                        cell_act, cand_act)
+            h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
+                                        act_cell, act_cand)
             hidden.append(h_pre.flatten())
             cell.append(c_pre.flatten())
             gate.append(g_pre.flatten())

     hidden = np.array(hidden).astype("float64")
     cell = np.array(cell).astype("float64")
     gate = np.array(gate).astype("float64")
+
+    hidden = _reverse(hidden, offset) if is_reverse else hidden
+    cell = _reverse(cell, offset) if is_reverse else cell
+
     assert gate.shape == input.shape
     assert hidden.shape == (input.shape[0], input.shape[1] / 4)
     assert cell.shape == (input.shape[0], input.shape[1] / 4)
     return hidden, cell, gate


-class LstmUnitTest(OpTest):
+class TestLstmOp(OpTest):
     def set_data(self):
-        D = 4
-        #lod = [[0, 2, 6, 9]]
-        lod = [[0, 1]]
-        shape = (1, D)
-
-        x = np.random.normal(size=(1, 4 * D)).astype("float64")
-        h0 = np.zeros((4, D)).astype("float64")
-        c0 = np.zeros((4, D)).astype("float64")
-        w = np.random.normal(size=(D, 4 * D)).astype("float64")
-        b = np.random.normal(size=(1, 7 * D)).astype("float64")
-
-        w_b = b[:, 0:4 * D]
-        w_c = b[:, 4 * D:]
-        #h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, sigmoid, tanh, tanh)
-        h, c, g = lstm(x, lod, h0, c0, w, w_b, w_c, False, identity, identity,
-                       identity)
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 64
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
+        self.act_gate = "sigmoid"
+        self.act_cell = "tanh"
+        self.act_cand = "tanh"
+
+        self.is_reverse = False
+
+    def setUp(self):
+        self.set_data()
+        self.op_type = "lstm"
+
+        T = self.lod[0][-1]
+        N = len(self.lod[0]) - 1
+
+        x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
+        h0 = np.zeros((N, self.D)).astype("float64")
+        c0 = np.zeros((N, self.D)).astype("float64")
+        w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
+        b = np.random.normal(size=(1, 7 * self.D)).astype("float64")
+
+        w_b = b[:, 0:4 * self.D]
+        w_c = b[:, 4 * self.D:]
+        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
+                       ACTVATION[self.act_gate], ACTVATION[self.act_cell],
+                       ACTVATION[self.act_cand])

         g_sort = np.zeros_like(x)
-        #idx = [2,6,0,3,7,1,4,8,5]
-        #for i, j in enumerate(idx):
-        #    g_sort[i, :] = g[j, :]
+        for i, j in enumerate(self.sort_idx):
+            g_sort[i, :] = g[j, :]

         self.inputs = {
-            'Input': (x, lod),
+            'Input': (x, self.lod),
             'H0': h0,
             'C0': c0,
             'Weight': w,
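The hard-coded sort_idx is the row order expected in BatchGate: presumably the sequences are reordered by descending length and then interleaved one time step at a time. A sketch of that assumed derivation, which does reproduce [2, 6, 0, 3, 7, 1, 4, 8, 5] for lod = [[0, 2, 6, 9]] (this is an illustration, not the operator's actual code):

    # Assumed batch ordering: longest sequence first, one row per live
    # sequence per time step.
    offset = [0, 2, 6, 9]
    lengths = [offset[i + 1] - offset[i] for i in range(len(offset) - 1)]  # [2, 4, 3]
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])         # [1, 2, 0]
    sort_idx = [
        offset[i] + t
        for t in range(max(lengths))
        for i in order
        if t < lengths[i]
    ]
    print(sort_idx)  # [2, 6, 0, 3, 7, 1, 4, 8, 5]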
@@ -127,19 +158,28 @@ def set_data(self):
         self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
         self.attrs = {
             'usePeepholes': True,
-            'isReverse': False,
-            'gateActivation': 'linear',
-            'cellActivation': 'linear',
-            'candidateActivation': 'linear'
+            'isReverse': self.is_reverse,
+            'gateActivation': 'sigmoid',
+            'cellActivation': 'tanh',
+            'candidateActivation': 'tanh'
         }

-    def setUp(self):
-        self.set_data()
-        self.op_type = "lstm"
-
     def test_check_output(self):
         self.check_output()


+class TestLstmOpRerverse(TestLstmOp):
+    def set_data(self):
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 64
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
+        self.act_gate = "sigmoid"
+        self.act_cell = "tanh"
+        self.act_cand = "tanh"
+
+        self.is_reverse = True
+
+
 if __name__ == "__main__":
     unittest.main()
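Further configurations can be covered the same way TestLstmOpRerverse does, by subclassing TestLstmOp and overriding set_data. A hypothetical extra case (not part of this change; the class name and values are assumptions) with a single sequence, where the batch-gate order is simply the original row order:

    # Hypothetical variant for illustration only, not part of the PR.
    class TestLstmOpSingleSeq(TestLstmOp):
        def set_data(self):
            self.lod = [[0, 4]]
            self.D = 64
            self.sort_idx = [0, 1, 2, 3]

            self.act_gate = "sigmoid"
            self.act_cell = "tanh"
            self.act_cand = "tanh"

            self.is_reverse = False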