24
24
namespace paddle {
25
25
namespace operators {
26
26
27
- using Tensor = framework::Tensor;
28
27
using LoDTensor = framework::LoDTensor;
28
+ using Tensor = framework::Tensor;
29
+
30
+ template <typename Place, typename T>
31
+ inline void ReorderInitState (const platform::DeviceContext& ctx,
32
+ const framework::Tensor& src, const size_t * index,
33
+ framework::Tensor* dst, bool indexed_src) {
34
+ math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
35
+ dst->mutable_data <T>(src.dims (), ctx.GetPlace ());
36
+ row_shuffle (ctx, src, index, *dst, indexed_src);
37
+ }
29
38
30
39
template <typename Place, typename T>
31
40
class GRUKernel : public framework ::OpKernel<T> {
32
41
public:
33
42
void BatchCompute (const framework::ExecutionContext& context) const {
34
43
auto * input = context.Input <LoDTensor>(" Input" );
35
44
auto * h0 = context.Input <Tensor>(" H0" );
36
- const T* h0_data = h0 ? h0->data <T>() : nullptr ;
37
45
auto * weight = context.Input <Tensor>(" Weight" );
38
46
const T* weight_data = weight->data <T>();
39
47
auto * bias = context.Input <Tensor>(" Bias" );
@@ -66,7 +74,18 @@ class GRUKernel : public framework::OpKernel<T> {
66
74
gru_value.gateWeight = const_cast <T*>(weight_data);
67
75
gru_value.stateWeight =
68
76
const_cast <T*>(weight_data + 2 * frame_size * frame_size);
69
- gru_value.prevOutValue = const_cast <T*>(h0_data);
77
+ Tensor ordered_h0;
78
+ const size_t * order = batch_gate->lod ()[2 ].data ();
79
+ if (h0) {
80
+ // Since the batch computing for GRU reorders the input sequences
81
+ // according to their length. The initialized cell state also needs
82
+ // to reorder.
83
+ ReorderInitState<Place, T>(context.device_context (), *h0, order,
84
+ &ordered_h0, true );
85
+ gru_value.prevOutValue = ordered_h0.data <T>();
86
+ } else {
87
+ gru_value.prevOutValue = nullptr ;
88
+ }
70
89
auto batch_starts = batch_gate->lod ()[0 ];
71
90
size_t num_batch = batch_starts.size () - 1 ;
72
91
for (size_t n = 0 ; n < num_batch; n++) {
@@ -102,7 +121,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
102
121
public:
103
122
void BatchCompute (const framework::ExecutionContext& context) const {
104
123
auto * h0 = context.Input <Tensor>(" H0" );
105
- const T* h0_data = h0 ? h0->data <T>() : nullptr ;
106
124
auto * weight = context.Input <Tensor>(" Weight" );
107
125
const T* weight_data = weight->data <T>();
108
126
auto * batch_gate = context.Input <LoDTensor>(" BatchGate" );
@@ -135,6 +153,17 @@ class GRUGradKernel : public framework::OpKernel<T> {
135
153
zero (dev_ctx, &batch_gate_grad, static_cast <T>(0.0 ));
136
154
zero (dev_ctx, &batch_reset_hidden_prev_grad, static_cast <T>(0.0 ));
137
155
156
+ Tensor ordered_h0, ordered_h0_grad;
157
+ const size_t * order = batch_gate->lod ()[2 ].data ();
158
+ if (h0) {
159
+ ReorderInitState<Place, T>(context.device_context (), *h0, order,
160
+ &ordered_h0, true );
161
+ }
162
+ if (h0_grad) {
163
+ ordered_h0_grad.mutable_data <T>(h0_grad->dims (), context.GetPlace ());
164
+ zero (context.device_context (), &ordered_h0_grad, static_cast <T>(0.0 ));
165
+ }
166
+
138
167
bool is_reverse = context.Attr <bool >(" is_reverse" );
139
168
batch_hidden_grad.set_lod (batch_hidden->lod ());
140
169
to_batch (dev_ctx, *hidden_grad, batch_hidden_grad, false , is_reverse);
@@ -176,14 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
176
205
batch_reset_hidden_prev_grad.Slice (bstart, bend);
177
206
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t .data <T>();
178
207
if (n == 0 ) {
179
- gru_value.prevOutValue = const_cast <T*>(h0_data);
180
- if (h0_grad) {
181
- T* h0_grad_data = h0_grad->mutable_data <T>(context.GetPlace ());
182
- zero (dev_ctx, h0_grad, static_cast <T>(0.0 ));
183
- gru_grad.prevOutGrad = h0_grad_data;
184
- } else {
185
- gru_grad.prevOutGrad = nullptr ;
186
- }
208
+ gru_value.prevOutValue = h0 ? ordered_h0.data <T>() : nullptr ;
209
+ gru_grad.prevOutGrad =
210
+ h0 && h0_grad ? ordered_h0_grad.data <T>() : nullptr ;
187
211
} else {
188
212
int bstart_pre = static_cast <int >(batch_starts[n - 1 ]);
189
213
Tensor hidden_prev_t = batch_hidden->Slice (bstart_pre, bstart);
@@ -208,6 +232,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
208
232
math::ColwiseSum<Place, T> col_sum;
209
233
col_sum (dev_ctx, batch_gate_grad, bias_grad);
210
234
}
235
+ if (h0 && h0_grad) {
236
+ ReorderInitState<Place, T>(context.device_context (), ordered_h0_grad,
237
+ order, h0_grad, false );
238
+ }
211
239
}
212
240
213
241
void Compute (const framework::ExecutionContext& context) const override {
0 commit comments