Skip to content

Commit 6f08a21

Browse files
authored
add gru unit layer wrapper (#6325)
1 parent e09e21b commit 6f08a21

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

python/paddle/v2/fluid/layers.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,77 @@ def dynamic_lstm(input,
180180
return hidden, cell
181181

182182

183+
def gru_unit(input,
             hidden,
             size,
             weight=None,
             bias=None,
             activation='tanh',
             gate_activation='sigmoid',
             main_program=None,
             startup_program=None):
    """
    GRUUnit Operator implements partial calculations of the GRU unit as following:

    $$
    update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
    reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
    output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
    output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
    $$

    which is same as one time step of GRU Operator.

    @note To implement the complete GRU unit, fully-connected operator must be
    used before to feed xu, xr and xc as the Input of GRUUnit operator.

    Args:
        input: The projected input for the current time step
            (xu, xr, xc concatenated; width ``size``).
        hidden: The hidden state from the previous time step.
        size (int): Width of ``input``, i.e. 3 * hidden_size.
        weight: Gate weight parameter; created automatically when None.
        bias: Gate bias parameter; created automatically when None.
        activation (str): Activation for the output candidate, one of
            'identity', 'sigmoid', 'tanh', 'relu'. Default 'tanh'.
        gate_activation (str): Activation for the update/reset gates,
            one of 'identity', 'sigmoid', 'tanh', 'relu'. Default 'sigmoid'.
        main_program: Program to add the op to (framework plumbing).
        startup_program: Program for parameter initialization.

    Returns:
        tuple: (updated_hidden, reset_hidden_pre, gate) variables.

    Raises:
        KeyError: If ``activation`` or ``gate_activation`` is not a
            recognized activation name.

    TODO(ChunweiYan) add more document here
    """
    activation_dict = dict(
        identity=0,
        sigmoid=1,
        tanh=2,
        relu=3, )
    # Resolve activation names to the integer codes the GRUUnit op expects.
    # This must happen before LayerHelper(**locals()) so the helper sees the
    # resolved codes, matching the original statement order.
    activation = activation_dict[activation]
    gate_activation = activation_dict[gate_activation]

    helper = LayerHelper('gru_unit', **locals())
    dtype = helper.input_dtype()
    # `size` is the concatenated gate width (3 * hidden_size); recover the
    # hidden size. Floor division keeps the result an int on Python 3,
    # where `/` would yield a float and break the shape lists below.
    size = size // 3

    # create weight
    if weight is None:
        weight = helper.create_parameter(
            attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)

    # create bias
    # NOTE(review): `bias` is created here but never passed to append_op's
    # inputs below — presumably the op expects a 'Bias' input; confirm
    # against the GRUUnit operator definition.
    if bias is None:
        bias_size = [1, 3 * size]
        bias = helper.create_parameter(
            attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)

    gate = helper.create_tmp_variable(dtype)
    reset_hidden_pre = helper.create_tmp_variable(dtype)
    updated_hidden = helper.create_tmp_variable(dtype)

    helper.append_op(
        type='gru_unit',
        inputs={'Input': input,
                'HiddenPrev': hidden,
                'Weight': weight},
        outputs={
            'Gate': gate,
            'ResetHiddenPrev': reset_hidden_pre,
            'Hidden': updated_hidden,
        },
        # BUG FIX: the original hard-coded activation=0 ('identity') and
        # gate_activation=1 ('sigmoid'), silently ignoring the user-selected
        # activations resolved above. Pass the resolved codes instead.
        attrs={
            'activation': activation,
            'gate_activation': gate_activation,
        })

    return updated_hidden, reset_hidden_pre, gate
252+
253+
183254
def data(name,
184255
shape,
185256
append_batch_size=True,

0 commit comments

Comments
 (0)