@@ -180,6 +180,79 @@ def dynamic_lstm(input,
     return hidden, cell


+def gru_unit(input,
+             hidden,
+             size,
+             weight=None,
+             bias=None,
+             activation='tanh',
+             gate_activation='sigmoid',
+             main_program=None,
+             startup_program=None):
+    """
+    GRUUnit Operator implements partial calculations of the GRU unit as follows:
+
+    $$
+    update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
+    reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
+    output \ candidate: \tilde{h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
+    output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, \tilde{h}_t)
+    $$
+
+    which is the same as one time step of the GRU Operator.
+
+    @note To implement the complete GRU unit, a fully-connected operator must be
+    applied beforehand to feed xu, xr and xc as the Input of the GRUUnit operator.
+
+    TODO(ChunweiYan) add more documentation here
+    """
+    activation_dict = dict(
+        identity=0,
+        sigmoid=1,
+        tanh=2,
+        relu=3, )
+    activation = activation_dict[activation]
+    gate_activation = activation_dict[gate_activation]
+
+    helper = LayerHelper('gru_unit', **locals())
+    dtype = helper.input_dtype()
+    # the caller passes size == 3 * hidden_size; recover the hidden size
+    size = size // 3
+
+    # create weight
+    if weight is None:
+        weight = helper.create_parameter(
+            attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
+
+    # create bias
+    if bias is None:
+        bias_size = [1, 3 * size]
+        bias = helper.create_parameter(
+            attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
+
+    gate = helper.create_tmp_variable(dtype)
+    reset_hidden_pre = helper.create_tmp_variable(dtype)
+    updated_hidden = helper.create_tmp_variable(dtype)
+
+    helper.append_op(
+        type='gru_unit',
+        inputs={'Input': input,
+                'HiddenPrev': hidden,
+                'Weight': weight},
+        outputs={
+            'Gate': gate,
+            'ResetHiddenPrev': reset_hidden_pre,
+            'Hidden': updated_hidden,
+        },
+        attrs={
+            # pass the resolved activation codes rather than hard-coded values
+            'activation': activation,
+            'gate_activation': gate_activation,
+        })
+
+    return updated_hidden, reset_hidden_pre, gate
+
+
 def data(name,
          shape,
          append_batch_size=True,
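To make the intended call pattern concrete, a minimal usage sketch follows. It is illustrative only: the import path and the fc helper's keyword arguments are assumptions (the data helper appears in the surrounding context), and hidden_size = 16 is an arbitrary example value. It demonstrates the docstring's note that the step input must first be projected to 3 * hidden_size by a fully-connected layer before gru_unit can split it into the xu, xr and xc parts.

    # Hypothetical usage sketch -- names, shapes and the import path are illustrative.
    import paddle.v2.fluid.layers as layers  # assumed location of this layers module

    hidden_size = 16  # arbitrary example size

    # Step input and previous hidden state for a single GRU time step.
    x = layers.data(name='x', shape=[32], dtype='float32')
    prev_hidden = layers.data(name='prev_hidden', shape=[hidden_size], dtype='float32')

    # Per the docstring note: project the input to 3 * hidden_size with a
    # fully-connected layer so gru_unit can split it into xu, xr and xc.
    x_proj = layers.fc(input=x, size=3 * hidden_size)

    # gru_unit expects `size` to be 3 * hidden_size; it divides by 3 internally.
    hidden, reset_hidden_pre, gate = layers.gru_unit(
        input=x_proj, hidden=prev_hidden, size=3 * hidden_size)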