Skip to content

Commit 3e7fff4

Browse files
committed
Fix calculations in gru_unit_op
1 parent 01d6ccb commit 3e7fff4

File tree

3 files changed

+16
-23
lines changed

3 files changed

+16
-23
lines changed

paddle/operators/gru_unit_op.cc

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -114,18 +114,19 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
114114
.SetDefault(sigmoid)
115115
.InEnum({identity, sigmoid, tanh, relu});
116116
AddComment(R"DOC(
117-
GRUUnit Operator.
118-
119-
This operator implements partial calculations of the GRU unit as follows:
117+
GRUUnit Operator implements partial calculations of the GRU unit as follows:
120118
121119
$$
122-
update \ gate: u_t = actGate(xu_t + W_u * hidden_{prev} + bias_u) \\
123-
reset \ gate: r_t = actGate(xr_t + W_r * hidden_{prev} + bias_r) \\
124-
output \ candidate: {h}_t = actNode({xc}_t + W_c * dot(r_t, hidden_{prev}) + bias_c) \\
125-
output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_{prev})
120+
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
121+
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
122+
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
123+
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
126124
$$
127125
128-
The rest of GRU unit can be completed by using FCOp's output as the input of GRUUnitOp.
126+
which is same as one time step of GRU Operator.
127+
128+
@note To implement the complete GRU unit, a fully-connected operator must be
129+
used beforehand to feed xu, xr and xc as the Input of the GRUUnit operator.
129130
130131
)DOC");
131132
}
@@ -150,12 +151,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
150151
"ResetHiddenPrev");
151152
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
152153
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
153-
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
154-
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
155-
"Gate");
156-
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
157-
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
158-
"ResetHiddenPrev");
159154
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
160155
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
161156
"Hidden");

paddle/operators/gru_unit_op.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
110110
auto c = g.slice(c_offsets, extents); // output candidate
111111

112112
// calculate final output
113-
h.device(place) = u * (h_p - c) + c;
113+
h.device(place) = u * (c - h_p) + h_p;
114114
}
115115
};
116116

@@ -185,10 +185,10 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
185185

186186
// backward for unactivated update gate
187187
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
188-
d_g.slice(u_offsets, extents), d_h * (h_p - c));
188+
d_g.slice(u_offsets, extents), d_h * (c - h_p));
189189
// backward for unactivated output candidate
190190
ActGradCompute(context.Attr<int>("activation"), place, c, c,
191-
d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
191+
d_g.slice(c_offsets, extents), d_h * u);
192192
// backward for reset_hidden_prev
193193
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
194194
frame_size, frame_size, 1,
@@ -210,7 +210,7 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
210210
frame_size, gate_grad_data, frame_size * 3, 0,
211211
weight_grad_data, frame_size * 2);
212212
// backward for hidden_prev
213-
d_h_p.device(place) = d_r_h_p * r + d_h * u;
213+
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
214214
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
215215
frame_size, frame_size * 2, 1, gate_grad_data,
216216
frame_size * 3, weight_data, frame_size * 2, 1,

python/paddle/v2/fluid/tests/test_gru_unit_op.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def set_outputs(self):
7777
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
7878
g[:, frame_size * 2:])
7979
g = np.hstack((u_r, c))
80-
h = u * h_p + (1 - u) * c
80+
h = u * c + (1 - u) * h_p
8181
self.outputs = {
8282
'Gate': g.astype('float64'),
8383
'ResetHiddenPrev': r_h_p.astype('float64'),
@@ -93,8 +93,7 @@ def test_check_output(self):
9393

9494
def test_check_grad(self):
9595
self.check_grad(
96-
['Input', 'HiddenPrev', 'Weight'],
97-
['Hidden', 'ResetHiddenPrev', 'Gate'],
96+
['Input', 'HiddenPrev', 'Weight'], ['Hidden'],
9897
max_relative_error=0.007)
9998

10099

@@ -104,7 +103,7 @@ def set_inputs(self):
104103
frame_size = self.frame_size
105104
super(TestGRUUnitOpWithBias, self).set_inputs()
106105
self.inputs['Bias'] = np.random.uniform(
107-
-0.1, 0.1, (1, frame_size * 3)).astype('float32')
106+
-0.1, 0.1, (1, frame_size * 3)).astype('float64')
108107
self.attrs = {
109108
'activation': GRUActivationType.identity,
110109
'gate_activation': GRUActivationType.sigmoid
@@ -117,5 +116,4 @@ def test_check_grad(self):
117116

118117

119118
if __name__ == '__main__':
120-
exit(0) # FIXME(yuyang18): This unittest is not pass. Fix it later
121119
unittest.main()

0 commit comments

Comments (0)