Skip to content

Commit 7a57b3b

Browse files
authored
Merge pull request #5623 from guoshengCS/fix-H0-GRUOp
Fix data order of H0 in GRU Operator
2 parents 093c526 + aa83e19 commit 7a57b3b

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

paddle/operators/gru_op.h

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,24 @@
2424
namespace paddle {
2525
namespace operators {
2626

27-
using Tensor = framework::Tensor;
2827
using LoDTensor = framework::LoDTensor;
28+
using Tensor = framework::Tensor;
29+
30+
template <typename Place, typename T>
31+
inline void ReorderInitState(const platform::DeviceContext& ctx,
32+
const framework::Tensor& src, const size_t* index,
33+
framework::Tensor* dst, bool indexed_src) {
34+
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
35+
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
36+
row_shuffle(ctx, src, index, *dst, indexed_src);
37+
}
2938

3039
template <typename Place, typename T>
3140
class GRUKernel : public framework::OpKernel<T> {
3241
public:
3342
void BatchCompute(const framework::ExecutionContext& context) const {
3443
auto* input = context.Input<LoDTensor>("Input");
3544
auto* h0 = context.Input<Tensor>("H0");
36-
const T* h0_data = h0 ? h0->data<T>() : nullptr;
3745
auto* weight = context.Input<Tensor>("Weight");
3846
const T* weight_data = weight->data<T>();
3947
auto* bias = context.Input<Tensor>("Bias");
@@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel<T> {
6674
gru_value.gateWeight = const_cast<T*>(weight_data);
6775
gru_value.stateWeight =
6876
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
69-
gru_value.prevOutValue = const_cast<T*>(h0_data);
77+
Tensor ordered_h0;
78+
const size_t* order = batch_gate->lod()[2].data();
79+
if (h0) {
80+
// Since the batch computing for GRU reorders the input sequences
81+
// according to their length, the initial hidden state also needs
82+
// to be reordered.
83+
ReorderInitState<Place, T>(context.device_context(), *h0, order,
84+
&ordered_h0, true);
85+
gru_value.prevOutValue = ordered_h0.data<T>();
86+
} else {
87+
gru_value.prevOutValue = nullptr;
88+
}
7089
auto batch_starts = batch_gate->lod()[0];
7190
size_t num_batch = batch_starts.size() - 1;
7291
for (size_t n = 0; n < num_batch; n++) {
@@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
102121
public:
103122
void BatchCompute(const framework::ExecutionContext& context) const {
104123
auto* h0 = context.Input<Tensor>("H0");
105-
const T* h0_data = h0 ? h0->data<T>() : nullptr;
106124
auto* weight = context.Input<Tensor>("Weight");
107125
const T* weight_data = weight->data<T>();
108126
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
@@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel<T> {
135153
zero(dev_ctx, &batch_gate_grad, static_cast<T>(0.0));
136154
zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
137155

156+
Tensor ordered_h0, ordered_h0_grad;
157+
const size_t* order = batch_gate->lod()[2].data();
158+
if (h0) {
159+
ReorderInitState<Place, T>(context.device_context(), *h0, order,
160+
&ordered_h0, true);
161+
}
162+
if (h0_grad) {
163+
ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
164+
zero(context.device_context(), &ordered_h0_grad, static_cast<T>(0.0));
165+
}
166+
138167
bool is_reverse = context.Attr<bool>("is_reverse");
139168
batch_hidden_grad.set_lod(batch_hidden->lod());
140169
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
@@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
176205
batch_reset_hidden_prev_grad.Slice(bstart, bend);
177206
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
178207
if (n == 0) {
179-
gru_value.prevOutValue = const_cast<T*>(h0_data);
180-
if (h0_grad) {
181-
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
182-
zero(dev_ctx, h0_grad, static_cast<T>(0.0));
183-
gru_grad.prevOutGrad = h0_grad_data;
184-
} else {
185-
gru_grad.prevOutGrad = nullptr;
186-
}
208+
gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
209+
gru_grad.prevOutGrad =
210+
h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
187211
} else {
188212
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
189213
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
@@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
208232
math::ColwiseSum<Place, T> col_sum;
209233
col_sum(dev_ctx, batch_gate_grad, bias_grad);
210234
}
235+
if (h0 && h0_grad) {
236+
ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
237+
order, h0_grad, false);
238+
}
211239
}
212240

213241
void Compute(const framework::ExecutionContext& context) const override {

python/paddle/v2/fluid/tests/test_gru_op.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77

88
class TestGRUOp(OpTest):
9-
batch_size = 9
9+
lod = [[0, 2, 6, 9]]
10+
batch_size = lod[0][-1]
1011
frame_size = 5
1112
activate = {
1213
'identity': identity,
@@ -35,7 +36,7 @@ def seq_to_batch(lod, is_reverse):
3536
seq_starts[sorted_seqs[i]] + batch_idx)
3637
idx_in_seq.append(idx)
3738
idx_in_seq_list.append(idx_in_seq)
38-
return idx_in_seq_list
39+
return idx_in_seq_list, sorted_seqs
3940

4041
def gru_step(self, x, h_p, w, b):
4142
batch_size = x.shape[0]
@@ -66,8 +67,8 @@ def gru(self):
6667
batch_hidden = self.outputs['BatchHidden']
6768
hidden = self.outputs['Hidden']
6869
idx_in_seq_list = self.idx_in_seq_list
69-
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
70-
(len(idx_in_seq_list[0]), self.frame_size))
70+
h_p = self.inputs['H0'][self.sorted_seqs] if self.inputs.has_key(
71+
'H0') else np.zeros((len(idx_in_seq_list[0]), self.frame_size))
7172
num_batch = len(idx_in_seq_list)
7273
end_idx = 0
7374
for batch_idx in range(num_batch):
@@ -84,8 +85,9 @@ def gru(self):
8485
return batch_gate, batch_reset_hidden_prev, hidden
8586

8687
def set_data(self):
87-
lod = [[0, 2, 6, self.batch_size]]
88-
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
88+
lod = self.lod
89+
self.idx_in_seq_list, self.sorted_seqs = self.seq_to_batch(
90+
lod, self.is_reverse)
8991
batch_size = self.batch_size
9092
frame_size = self.frame_size
9193
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
@@ -146,7 +148,7 @@ class TestGRUOpReverse(TestGRUOp):
146148
def set_confs(self):
147149
self.is_reverse = True
148150
self.attrs = {
149-
'activation': 'identity',
151+
'activation': 'tanh',
150152
'gate_activation': 'sigmoid',
151153
'is_reverse': self.is_reverse
152154
}

0 commit comments

Comments
 (0)