
Commit 778b981

Merge pull request #5804 from guoshengCS/fix-GRUUnitOp-dev

Fix calculations in gru_unit_op to be consistent with gru_op

2 parents: 23741aa + b6b7ab6

3 files changed: +60 −60 lines changed


paddle/operators/gru_unit_op.cc

Lines changed: 9 additions & 14 deletions
@@ -114,18 +114,19 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(sigmoid)
         .InEnum({identity, sigmoid, tanh, relu});
     AddComment(R"DOC(
-GRUUnit Operator.
-
-This operator implements partial calculations of the GRU unit as follows:
+GRUUnit Operator implements partial calculations of the GRU unit as follows:
 
 $$
-update \ gate: u_t = actGate(xu_t + W_u * hidden_{prev} + bias_u) \\
-reset \ gate: r_t = actGate(xr_t + W_r * hidden_{prev} + bias_r) \\
-output \ candidate: {h}_t = actNode({xc}_t + W_c * dot(r_t, hidden_{prev}) + bias_c) \\
-output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_{prev})
+update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
+reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
+output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
+output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
 $$
 
-The rest of GRU unit can be completed by using FCOp's output as the input of GRUUnitOp.
+which is the same as one time step of the GRU Operator.
+
+@note To implement the complete GRU unit, a fully-connected operator must be
+used beforehand to feed xu, xr and xc as the Input of the GRUUnit operator.
 
 )DOC");
   }
@@ -150,12 +151,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
                    "ResetHiddenPrev");
     PADDLE_ENFORCE(ctx->HasInput("Hidden"),
                    "Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
-                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
-                   "Gate");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
-                   "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
-                   "ResetHiddenPrev");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
                    "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
                    "Hidden");

paddle/operators/gru_unit_op.h

Lines changed: 41 additions & 35 deletions
@@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     auto c = g.slice(c_offsets, extents);  // output candidate
 
     // calculate final output
-    h.device(place) = u * (h_p - c) + c;
+    h.device(place) = u * (c - h_p) + h_p;
   }
 };
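The new expression is just a refactored form of the interpolation used by gru_op: u * (c - h_p) + h_p expands to (1 - u) * h_p + u * c, so the update gate now moves the state toward the candidate rather than the reverse. A quick NumPy check of the identity:

import numpy as np

rng = np.random.RandomState(0)
u, c, h_p = rng.rand(3, 4), rng.rand(3, 4), rng.rand(3, 4)

# fused form used in the kernel ...
h_fused = u * (c - h_p) + h_p
# ... equals the interpolation h_t = (1 - u) * h_{t-1} + u * c
assert np.allclose(h_fused, (1.0 - u) * h_p + u * c)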

@@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
     auto* weight_grad =
         context.Output<Tensor>(framework::GradVarName("Weight"));
     auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
-    input_grad->mutable_data<T>(context.GetPlace());
-    hidden_prev_grad->mutable_data<T>(context.GetPlace());
-    weight_grad->mutable_data<T>(context.GetPlace());
     Tensor gate_grad;
-    gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
     Tensor reset_hidden_prev_grad;
-    reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
-                                           context.GetPlace());
-
-    int batch_size = input->dims()[0];
-    int frame_size = hidden_prev->dims()[1];
 
     const T* hidden_prev_data = hidden_prev->data<T>();
-    T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
     const T* weight_data = weight->data<T>();
-    T* weight_grad_data = weight_grad->data<T>();
-    T* gate_grad_data = gate_grad.data<T>();
+    T* gate_grad_data =
+        gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
     const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
-    T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>();
+    T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data<T>(
+        reset_hidden_prev->dims(), context.GetPlace());
 
     auto h_p = EigenMatrix<T>::From(*hidden_prev);
     auto g = EigenMatrix<T>::From(*gate);
     auto d_h = EigenMatrix<T>::From(*hidden_grad);
-    auto d_x = EigenMatrix<T>::From(*input_grad);
-    auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
     auto d_g = EigenMatrix<T>::From(gate_grad);
     auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
     auto place = context.GetEigenDevice<Place>();
 
+    int batch_size = input->dims()[0];
+    int frame_size = hidden_prev->dims()[1];
+
     Eigen::array<int, 2> extents({{batch_size, frame_size}});
     Eigen::array<int, 2> u_offsets({{0, 0}});
     auto u = g.slice(u_offsets, extents);  // update gate
@@ -185,38 +177,52 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
 
     // backward for unactivated update gate
     ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
-                   d_g.slice(u_offsets, extents), d_h * (h_p - c));
+                   d_g.slice(u_offsets, extents), d_h * (c - h_p));
     // backward for unactivated output candidate
     ActGradCompute(context.Attr<int>("activation"), place, c, c,
-                   d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u));
+                   d_g.slice(c_offsets, extents), d_h * u);
     // backward for reset_hidden_prev
     math::gemm<Place, T>(context.device_context(), false, true, batch_size,
                          frame_size, frame_size, 1,
                          gate_grad_data + frame_size * 2, frame_size * 3,
                          weight_data + frame_size * frame_size * 2, frame_size,
                          0, reset_hidden_prev_grad_data, frame_size);
-    // backward for state_weight
-    math::gemm<Place, T>(
-        context.device_context(), true, false, frame_size, frame_size,
-        batch_size, 1, reset_hidden_prev_data, frame_size,
-        gate_grad_data + frame_size * 2, frame_size * 3, 0,
-        weight_grad_data + frame_size * frame_size * 2, frame_size);
     // backward for unactivated reset gate
     ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
                    d_g.slice(r_offsets, extents), d_r_h_p * h_p);
-    // backward for update_gate_weight and reset_gate_weight
-    math::gemm<Place, T>(context.device_context(), true, false, frame_size,
-                         frame_size * 2, batch_size, 1, hidden_prev_data,
-                         frame_size, gate_grad_data, frame_size * 3, 0,
-                         weight_grad_data, frame_size * 2);
+    // backward for weight
+    if (weight_grad) {
+      T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
+      // backward for state_weight
+      math::gemm<Place, T>(
+          context.device_context(), true, false, frame_size, frame_size,
+          batch_size, 1, reset_hidden_prev_data, frame_size,
+          gate_grad_data + frame_size * 2, frame_size * 3, 0,
+          weight_grad_data + frame_size * frame_size * 2, frame_size);
+
+      // backward for update_gate_weight and reset_gate_weight
+      math::gemm<Place, T>(context.device_context(), true, false, frame_size,
+                           frame_size * 2, batch_size, 1, hidden_prev_data,
+                           frame_size, gate_grad_data, frame_size * 3, 0,
+                           weight_grad_data, frame_size * 2);
+    }
     // backward for hidden_prev
-    d_h_p.device(place) = d_r_h_p * r + d_h * u;
-    math::gemm<Place, T>(context.device_context(), false, true, batch_size,
-                         frame_size, frame_size * 2, 1, gate_grad_data,
-                         frame_size * 3, weight_data, frame_size * 2, 1,
-                         hidden_prev_grad_data, frame_size);
+    if (hidden_prev_grad) {
+      T* hidden_prev_grad_data =
+          hidden_prev_grad->mutable_data<T>(context.GetPlace());
+      auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
+      d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
+      math::gemm<Place, T>(context.device_context(), false, true, batch_size,
+                           frame_size, frame_size * 2, 1, gate_grad_data,
+                           frame_size * 3, weight_data, frame_size * 2, 1,
+                           hidden_prev_grad_data, frame_size);
+    }
     // backward for input
-    d_x.device(place) = d_g;
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto d_x = EigenMatrix<T>::From(*input_grad);
+      d_x.device(place) = d_g;
+    }
     // backward for bias
     if (bias_grad) {
       bias_grad->mutable_data<T>(context.GetPlace());
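The backward changes mirror the new forward formula. With h = (1 - u) * h_p + u * c, the elementwise terms fed into the activation gradients become d_h * (c - h_p) for the update gate and d_h * u for the candidate, and the hidden_prev gradient picks up d_h * (1 - u). A hedged NumPy sketch of just those pieces, assuming identity activations (so the factors ActGradCompute applies in the kernel, and the gemm term through the gate weights, are omitted):

import numpy as np


def gru_unit_elementwise_grads(d_h, d_r_h_p, u, r, c, h_p):
    """Elementwise gradient pieces implied by h = (1 - u) * h_p + u * c."""
    d_u = d_h * (c - h_p)                  # was d_h * (h_p - c) before the fix
    d_c = d_h * u                          # was d_h * (1 - u) before the fix
    d_r = d_r_h_p * h_p                    # reset gate enters through r * h_p
    d_h_p = d_r_h_p * r + d_h * (1.0 - u)  # was d_r_h_p * r + d_h * u before
    return d_u, d_c, d_r, d_h_p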

python/paddle/v2/fluid/tests/test_gru_unit_op.py

Lines changed: 10 additions & 11 deletions
@@ -28,8 +28,8 @@ def relu(x):
 
 
 class TestGRUUnitOp(OpTest):
-    batch_size = 3
-    frame_size = 5
+    batch_size = 5
+    frame_size = 10
     activate = {
         GRUActivationType.identity: identity,
         GRUActivationType.sigmoid: sigmoid,
@@ -77,7 +77,7 @@ def set_outputs(self):
         c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
                                                     g[:, frame_size * 2:])
         g = np.hstack((u_r, c))
-        h = u * h_p + (1 - u) * c
+        h = u * c + (1 - u) * h_p
         self.outputs = {
             'Gate': g.astype('float64'),
             'ResetHiddenPrev': r_h_p.astype('float64'),
@@ -92,10 +92,7 @@ def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(
-            ['Input', 'HiddenPrev', 'Weight'],
-            ['Hidden', 'ResetHiddenPrev', 'Gate'],
-            max_relative_error=0.007)
+        self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
 
 
 class TestGRUUnitOpWithBias(TestGRUUnitOp):
@@ -104,18 +101,20 @@ def set_inputs(self):
         frame_size = self.frame_size
         super(TestGRUUnitOpWithBias, self).set_inputs()
         self.inputs['Bias'] = np.random.uniform(
-            -0.1, 0.1, (1, frame_size * 3)).astype('float32')
+            -0.1, 0.1, (1, frame_size * 3)).astype('float64')
         self.attrs = {
             'activation': GRUActivationType.identity,
             'gate_activation': GRUActivationType.sigmoid
         }
 
     def test_check_grad(self):
+        self.check_grad(['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'])
+
+    def test_check_grad_ingore_input(self):
         self.check_grad(
-            ['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
-            max_relative_error=0.007)
+            ['HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
+            no_grad_set=set('Input'))
 
 
 if __name__ == '__main__':
-    exit(0)  # FIXME(yuyang18): This unittest is not pass. Fix it later
     unittest.main()
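The check_grad calls above compare the operator's analytic gradients against numerical ones. Reduced to the fixed output formula, the idea looks roughly like this; numeric_grad is a hypothetical helper for illustration, not part of OpTest:

import numpy as np


def numeric_grad(f, x, eps=1e-6):
    """Central finite differences of scalar f() with respect to array x."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        f_pos = f()
        x[idx] = orig - eps
        f_neg = f()
        x[idx] = orig
        grad[idx] = (f_pos - f_neg) / (2.0 * eps)
        it.iternext()
    return grad


rng = np.random.RandomState(0)
u, c, h_p = rng.rand(2, 3), rng.rand(2, 3), rng.rand(2, 3)

# gradient of sum(h) w.r.t. h_p for h = (1 - u) * h_p + u * c is (1 - u)
analytic = 1.0 - u
numeric = numeric_grad(lambda: ((1.0 - u) * h_p + u * c).sum(), h_p)
assert np.allclose(analytic, numeric, atol=1e-6)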
