@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/operators/lstm_op.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -24,20 +25,12 @@
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
-
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
-
 template <typename Place, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     auto* input = context.Input<LoDTensor>("Input");
     auto* h0 = context.Input<Tensor>("H0");
-    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* bias = context.Input<Tensor>("Bias");
@@ -74,7 +67,18 @@ class GRUKernel : public framework::OpKernel<T> {
     gru_value.gateWeight = const_cast<T*>(weight_data);
     gru_value.stateWeight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    gru_value.prevOutValue = const_cast<T*>(h0_data);
+    Tensor ordered_h0;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (h0) {
+      // Since batch computing for GRU reorders the input sequences
+      // according to their length, the initial hidden state also needs
+      // to be reordered.
+      ReorderInitState<Place, T>(context.device_context(), *h0, order,
+                                 &ordered_h0, true);
+      gru_value.prevOutValue = ordered_h0.data<T>();
+    } else {
+      gru_value.prevOutValue = nullptr;
+    }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
     for (size_t n = 0; n < num_batch; n++) {
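For context: ReorderInitState (presumably pulled in through the new lstm_op.h include) with indexed_src == true gathers rows of h0 into the length-sorted order that sequence2batch appears to record in batch_gate->lod()[2]. A minimal CPU-only sketch of that gather, with GatherRows as a hypothetical std::vector stand-in rather than a Paddle API:

    #include <cstddef>
    #include <vector>

    // Hypothetical stand-in for the gather (indexed_src == true):
    // dst row i <- src row order[i]; each row holds `width` elements.
    template <typename T>
    void GatherRows(const std::vector<T>& src, const std::vector<size_t>& order,
                    size_t width, std::vector<T>* dst) {
      dst->assign(order.size() * width, T());
      for (size_t i = 0; i < order.size(); ++i) {
        for (size_t j = 0; j < width; ++j) {
          (*dst)[i * width + j] = src[order[i] * width + j];
        }
      }
    }

After this, row i of the reordered initial state belongs to the sequence that sequence2batch placed at position i of the batch.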
@@ -110,7 +114,6 @@ class GRUGradKernel : public framework::OpKernel<T> {
  public:
   void BatchCompute(const framework::ExecutionContext& context) const {
     auto* h0 = context.Input<Tensor>("H0");
-    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     auto* weight = context.Input<Tensor>("Weight");
     const T* weight_data = weight->data<T>();
     auto* batch_gate = context.Input<LoDTensor>("BatchGate");
@@ -143,6 +146,16 @@ class GRUGradKernel : public framework::OpKernel<T> {
     zero(context.device_context(), &batch_reset_hidden_prev_grad,
          static_cast<T>(0.0));
 
+    Tensor ordered_h0, ordered_h0_grad;
+    const size_t* order = batch_gate->lod()[2].data();
+    if (h0) {
+      ReorderInitState<Place, T>(context.device_context(), *h0, order,
+                                 &ordered_h0, true);
+    }
+    if (h0_grad) {
+      ordered_h0_grad.mutable_data<T>(h0_grad->dims(), context.GetPlace());
+    }
+
     bool is_reverse = context.Attr<bool>("is_reverse");
     batch_hidden_grad.set_lod(batch_hidden->lod());
     to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
@@ -185,11 +198,13 @@ class GRUGradKernel : public framework::OpKernel<T> {
           batch_reset_hidden_prev_grad.Slice(bstart, bend);
       gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        gru_value.prevOutValue = const_cast<T*>(h0_data);
-        if (h0_grad) {
-          T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
-          zero(context.device_context(), h0_grad, static_cast<T>(0.0));
-          gru_grad.prevOutGrad = h0_grad_data;
+        if (h0) {
+          gru_value.prevOutValue = ordered_h0.data<T>();
+        } else {
+          gru_value.prevOutValue = nullptr;
+        }
+        if (h0 && h0_grad) {
+          gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
         } else {
           gru_grad.prevOutGrad = nullptr;
         }
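In the backward kernel the first batch step (n == 0) uses the reordered h0 as its previous output, so the initial-state gradient lands in ordered_h0_grad, still in length-sorted order; the final hunk below scatters it back with ReorderInitState(..., false). A matching sketch of that inverse, in the same hypothetical std::vector style as the gather above:

    // Hypothetical stand-in for the scatter (indexed_src == false):
    // dst row order[i] <- src row i, restoring the caller-visible order.
    template <typename T>
    void ScatterRows(const std::vector<T>& src, const std::vector<size_t>& order,
                     size_t width, std::vector<T>* dst) {
      dst->assign(src.size(), T());
      for (size_t i = 0; i < order.size(); ++i) {
        for (size_t j = 0; j < width; ++j) {
          (*dst)[order[i] * width + j] = src[i * width + j];
        }
      }
    }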
@@ -220,6 +235,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
       auto place = context.GetEigenDevice<Place>();
       d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
     }
+    if (h0 && h0_grad) {
+      ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
+                                 order, h0_grad, false);
+    }
   }
 
   void Compute(const framework::ExecutionContext& context) const override {
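Gather and scatter are inverse permutations, which is what makes H0@GRAD line up row-for-row with the user-supplied H0 again. A toy round trip using the hypothetical helpers sketched above (all values made up for illustration):

    #include <cassert>

    int main() {
      std::vector<float> h0 = {0, 0, 1, 1, 2, 2};  // three rows, width 2
      std::vector<size_t> order = {2, 0, 1};       // e.g. descending length
      std::vector<float> ordered, restored;
      GatherRows(h0, order, 2, &ordered);          // rows now: 2, 0, 1
      ScatterRows(ordered, order, 2, &restored);   // back to 0, 1, 2
      assert(restored == h0);
      return 0;
    }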