@@ -52,24 +52,23 @@ def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
52
52
g = np .dot (h_pre , w_h ) # 1 x 4D
53
53
g = g + x
54
54
g = np .reshape (g , (1 , g .size ))
55
- c_tmp , g_i , g_f , g_o = np .split (g , 4 , axis = 1 )
55
+ c , g_i , g_f , g_o = np .split (g , 4 , axis = 1 )
56
56
if w_c is None :
57
57
g_i = act_gate (g_i ) # 1 x D
58
58
g_f = act_gate (g_f ) # 1 x D
59
59
else :
60
60
w_ic , w_fc , w_oc = np .split (w_c , 3 , axis = 1 )
61
61
g_i = act_gate (g_i + w_ic * c_pre ) # 1 x D
62
62
g_f = act_gate (g_f + w_fc * c_pre ) # 1 x D
63
- c = g_f * c_pre + g_i * act_cand (c_tmp ) # 1 x D
63
+ c = g_f * c_pre + g_i * act_cand (c ) # 1 x D
64
64
65
65
if w_c is None :
66
66
g_o = act_gate (g_o ) # 1 x D
67
67
else :
68
68
_ , _ , w_oc = np .split (w_c , 3 , axis = 1 )
69
69
g_o = act_gate (g_o + w_oc * c ) # 1 x D
70
70
h = g_o * act_cell (c )
71
- bg = np .concatenate ((act_cand (c_tmp ), g_i , g_f , g_o ), axis = 1 )
72
- return h , c , bg
71
+ return h , c
73
72
74
73
def _reverse (x , lod ):
75
74
y = np .zeros_like (x )
@@ -82,7 +81,6 @@ def _reverse(x, lod):
82
81
batch_size = len (offset ) - 1
83
82
hidden = []
84
83
cell = []
85
- gate = []
86
84
input = _reverse (input , offset ) if is_reverse else input
87
85
if w_b is not None :
88
86
input = input + np .tile (w_b , (offset [- 1 ], 1 ))
@@ -94,30 +92,26 @@ def _reverse(x, lod):
94
92
c_pre = c0 [i ] # 1 x D
95
93
for j in range (seq_len ):
96
94
# compute one step
97
- h_pre , c_pre , g_pre = _step (x [j ], w_h , w_c , h_pre , c_pre , act_gate ,
98
- act_cell , act_cand )
95
+ h_pre , c_pre = _step (x [j ], w_h , w_c , h_pre , c_pre , act_gate ,
96
+ act_cell , act_cand )
99
97
hidden .append (h_pre .flatten ())
100
98
cell .append (c_pre .flatten ())
101
- gate .append (g_pre .flatten ())
102
99
103
100
hidden = np .array (hidden ).astype ('float64' )
104
101
cell = np .array (cell ).astype ('float64' )
105
- gate = np .array (gate ).astype ('float64' )
106
102
107
103
hidden = _reverse (hidden , offset ) if is_reverse else hidden
108
104
cell = _reverse (cell , offset ) if is_reverse else cell
109
105
110
- assert gate .shape == input .shape
111
106
assert hidden .shape == (input .shape [0 ], input .shape [1 ] / 4 )
112
107
assert cell .shape == (input .shape [0 ], input .shape [1 ] / 4 )
113
- return hidden , cell , gate
108
+ return hidden , cell
114
109
115
110
116
111
class TestLstmOp (OpTest ):
117
112
def set_argument (self ):
118
- self .lod = [[0 , 2 , 6 , 9 ]]
113
+ self .lod = [[0 , 2 , 6 ]]
119
114
self .D = 16
120
- self .sort_idx = [2 , 6 , 0 , 3 , 7 , 1 , 4 , 8 , 5 ]
121
115
122
116
self .act_gate = 'sigmoid'
123
117
self .act_cell = 'tanh'
@@ -141,22 +135,18 @@ def setUp(self):
141
135
142
136
w_b = b [:, 0 :4 * self .D ]
143
137
w_c = b [:, 4 * self .D :]
144
- h , c , g = lstm (x , self .lod , h0 , c0 , w , w_b , w_c , self .is_reverse ,
145
- ACTVATION [self .act_gate ], ACTVATION [self .act_cell ],
146
- ACTVATION [self .act_cand ])
147
-
148
- g_sort = np .zeros_like (x )
149
- for i , j in enumerate (self .sort_idx ):
150
- g_sort [i , :] = g [j , :]
138
+ h , c = lstm (x , self .lod , h0 , c0 , w , w_b , w_c , self .is_reverse ,
139
+ ACTVATION [self .act_gate ], ACTVATION [self .act_cell ],
140
+ ACTVATION [self .act_cand ])
151
141
152
142
self .inputs = {'Input' : (x , self .lod ), 'Weight' : w , 'Bias' : b }
153
- self .inputs ['H0' ] = h0
154
- self .inputs ['C0' ] = c0
143
+ if self .has_initial_state :
144
+ self .inputs ['H0' ] = h0
145
+ self .inputs ['C0' ] = c0
155
146
156
147
self .outputs = {
157
148
'Hidden' : (h , self .lod ),
158
149
'Cell' : (c , self .lod ),
159
- 'BatchGate' : g_sort ,
160
150
}
161
151
self .attrs = {
162
152
'usePeepholes' : True ,
@@ -179,9 +169,8 @@ def test_check_grad(self):
179
169
180
170
class TestLstmOpHasNoInitial (TestLstmOp ):
181
171
def set_argument (self ):
182
- self .lod = [[0 , 2 , 6 , 9 ]]
183
- self .D = 64
184
- self .sort_idx = [2 , 6 , 0 , 3 , 7 , 1 , 4 , 8 , 5 ]
172
+ self .lod = [[0 , 2 , 6 ]]
173
+ self .D = 16
185
174
186
175
self .act_gate = 'sigmoid'
187
176
self .act_cell = 'tanh'
@@ -193,9 +182,8 @@ def set_argument(self):
193
182
194
183
class TestLstmOpRerverse (TestLstmOp ):
195
184
def set_argument (self ):
196
- self .lod = [[0 , 2 , 6 , 9 ]]
197
- self .D = 64
198
- self .sort_idx = [2 , 6 , 0 , 3 , 7 , 1 , 4 , 8 , 5 ]
185
+ self .lod = [[0 , 2 , 6 ]]
186
+ self .D = 16
199
187
200
188
self .act_gate = 'sigmoid'
201
189
self .act_cell = 'tanh'
0 commit comments