Skip to content

Commit b103072

Browse files
committed
Fix data order of H0 in GRU Operator
1 parent 80de144 commit b103072

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

paddle/operators/gru_op.h

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "paddle/operators/lstm_op.h"
1718
#include "paddle/operators/math/gru_compute.h"
1819
#include "paddle/operators/math/math_function.h"
1920
#include "paddle/operators/math/sequence2batch.h"
@@ -24,20 +25,12 @@
2425
namespace paddle {
2526
namespace operators {
2627

27-
using Tensor = framework::Tensor;
28-
using LoDTensor = framework::LoDTensor;
29-
30-
template <typename T, int MajorType = Eigen::RowMajor,
31-
typename IndexType = Eigen::DenseIndex>
32-
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
33-
3428
template <typename Place, typename T>
3529
class GRUKernel : public framework::OpKernel<T> {
3630
public:
3731
void BatchCompute(const framework::ExecutionContext& context) const {
3832
auto* input = context.Input<LoDTensor>("Input");
3933
auto* h0 = context.Input<Tensor>("H0");
40-
const T* h0_data = h0 ? h0->data<T>() : nullptr;
4134
auto* weight = context.Input<Tensor>("Weight");
4235
const T* weight_data = weight->data<T>();
4336
auto* bias = context.Input<Tensor>("Bias");
@@ -74,7 +67,18 @@ class GRUKernel : public framework::OpKernel<T> {
7467
gru_value.gateWeight = const_cast<T*>(weight_data);
7568
gru_value.stateWeight =
7669
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
77-
gru_value.prevOutValue = const_cast<T*>(h0_data);
70+
Tensor ordered_h0;
71+
const size_t* order = batch_gate->lod()[2].data();
72+
if (h0) {
73+
// Since the batch computing for GRU reorders the input sequences
74+
// according to their length, the initial hidden state also needs
75+
// to be reordered.
76+
ReorderInitState<Place, T>(context.device_context(), *h0, order,
77+
&ordered_h0, true);
78+
gru_value.prevOutValue = ordered_h0.data<T>();
79+
} else {
80+
gru_value.prevOutValue = nullptr;
81+
}
7882
auto batch_starts = batch_gate->lod()[0];
7983
size_t num_batch = batch_starts.size() - 1;
8084
for (size_t n = 0; n < num_batch; n++) {
@@ -110,7 +114,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
110114
public:
111115
void BatchCompute(const framework::ExecutionContext& context) const {
112116
auto* h0 = context.Input<Tensor>("H0");
113-
const T* h0_data = h0 ? h0->data<T>() : nullptr;
114117
auto* weight = context.Input<Tensor>("Weight");
115118
const T* weight_data = weight->data<T>();
116119
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
@@ -143,6 +146,16 @@ class GRUGradKernel : public framework::OpKernel<T> {
143146
zero(context.device_context(), &batch_reset_hidden_prev_grad,
144147
static_cast<T>(0.0));
145148

149+
Tensor ordered_h0, ordered_h0_grad;
150+
const size_t* order = batch_gate->lod()[2].data();
151+
if (h0) {
152+
ReorderInitState<Place, T>(context.device_context(), *h0, order,
153+
&ordered_h0, true);
154+
}
155+
if (h0_grad) {
156+
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
157+
}
158+
146159
bool is_reverse = context.Attr<bool>("is_reverse");
147160
batch_hidden_grad.set_lod(batch_hidden->lod());
148161
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
@@ -185,11 +198,13 @@ class GRUGradKernel : public framework::OpKernel<T> {
185198
batch_reset_hidden_prev_grad.Slice(bstart, bend);
186199
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
187200
if (n == 0) {
188-
gru_value.prevOutValue = const_cast<T*>(h0_data);
189-
if (h0_grad) {
190-
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
191-
zero(context.device_context(), h0_grad, static_cast<T>(0.0));
192-
gru_grad.prevOutGrad = h0_grad_data;
201+
if (h0) {
202+
gru_value.prevOutValue = ordered_h0.data<T>();
203+
} else {
204+
gru_value.prevOutValue = nullptr;
205+
}
206+
if (h0 && h0_grad) {
207+
gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
193208
} else {
194209
gru_grad.prevOutGrad = nullptr;
195210
}
@@ -220,6 +235,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
220235
auto place = context.GetEigenDevice<Place>();
221236
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
222237
}
238+
if (h0 && h0_grad) {
239+
ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
240+
order, h0_grad, false);
241+
}
223242
}
224243

225244
void Compute(const framework::ExecutionContext& context) const override {

python/paddle/v2/framework/tests/test_gru_op.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77

88
class TestGRUOp(OpTest):
9-
batch_size = 9
9+
lod = [[0, 2, 6, 9]]
10+
batch_size = lod[0][-1]
1011
frame_size = 5
1112
activate = {
1213
'identity': identity,
@@ -35,7 +36,7 @@ def seq_to_batch(lod, is_reverse):
3536
seq_starts[sorted_seqs[i]] + batch_idx)
3637
idx_in_seq.append(idx)
3738
idx_in_seq_list.append(idx_in_seq)
38-
return idx_in_seq_list
39+
return idx_in_seq_list, sorted_seqs
3940

4041
def gru_step(self, x, h_p, w, b):
4142
batch_size = x.shape[0]
@@ -66,8 +67,8 @@ def gru(self):
6667
batch_hidden = self.outputs['BatchHidden']
6768
hidden = self.outputs['Hidden']
6869
idx_in_seq_list = self.idx_in_seq_list
69-
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
70-
(len(idx_in_seq_list[0]), self.frame_size))
70+
h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key(
71+
'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size))
7172
num_batch = len(idx_in_seq_list)
7273
end_idx = 0
7374
for batch_idx in range(num_batch):
@@ -84,8 +85,9 @@ def gru(self):
8485
return batch_gate, batch_reset_hidden_prev, hidden
8586

8687
def set_data(self):
87-
lod = [[0, 2, 6, self.batch_size]]
88-
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
88+
lod = self.lod
89+
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
90+
lod, self.is_reverse)
8991
batch_size = self.batch_size
9092
frame_size = self.frame_size
9193
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
@@ -146,8 +148,8 @@ class TestGRUOpReverse(TestGRUOp):
146148
def set_confs(self):
147149
self.is_reverse = True
148150
self.attrs = {
149-
'activation': 'identity',
150-
'gate_activation': 'sigmoid',
151+
'activation': 'tanh',
152+
'gate_activation': 'tanh',
151153
'is_reverse': self.is_reverse
152154
}
153155

0 commit comments

Comments
 (0)