--- a/paddle/operators/gru_op.h
+++ b/paddle/operators/gru_op.h
@@ -14,7 +14,6 @@
 
 #pragma once
 
-#include "paddle/operators/lstm_op.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -25,6 +24,18 @@
 namespace paddle {
 namespace operators {
 
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+inline void ReorderInitState(const platform::DeviceContext& ctx,
+                             const framework::Tensor& src, const size_t* index,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index, *dst, indexed_src);
+}
+
 template <typename Place, typename T>
 class GRUKernel : public framework::OpKernel<T> {
  public:
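Note: the new ReorderInitState helper shuffles rows of src into dst following index via math::CopyMatrixRowsFunctor (indexed_src == true reads the indexed rows of src; false writes to the indexed rows of dst). A minimal sketch of how the forward kernel presumably calls it before the batched loop; `context` and the `order` index array are assumed names, only h0, ordered_h0, and gru_value.prevOutValue are visible in this diff:

    // Sketch only, not part of this commit. Batch computation sorts the
    // input sequences by length, so the rows of the initial hidden state
    // h0 must be shuffled into the same order before the first time step.
    if (h0) {
      ReorderInitState<Place, T>(context.device_context(), *h0, order,
                                 &ordered_h0, true);
      gru_value.prevOutValue = ordered_h0.data<T>();
    } else {
      gru_value.prevOutValue = nullptr;
    }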
@@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
           batch_reset_hidden_prev_grad.Slice(bstart, bend);
       gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
       if (n == 0) {
-        if (h0) {
-          gru_value.prevOutValue = ordered_h0.data<T>();
-        } else {
-          gru_value.prevOutValue = nullptr;
-        }
-        if (h0 && h0_grad) {
-          gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
-        } else {
-          gru_grad.prevOutGrad = nullptr;
-        }
+        gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
+        gru_grad.prevOutGrad =
+            h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
       } else {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
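The collapsed ternaries above are a behavior-preserving rewrite of the removed if/else blocks: prevOutValue and prevOutGrad still fall back to nullptr when no initial state (or its gradient) is provided. Symmetrically to the forward gather, the gradient accumulated in ordered_h0_grad is in batch order and presumably has to be scattered back into the caller's h0_grad after the loop; a hedged sketch using the helper's indexed_src == false mode (the exact call site is an assumption, not shown in this diff):

    // Sketch only: with indexed_src == false, CopyMatrixRowsFunctor treats
    // `index` as destination-row indices, i.e. the inverse of the forward
    // gather, restoring the original sequence order of the gradient.
    if (h0 && h0_grad) {
      ReorderInitState<Place, T>(context.device_context(), ordered_h0_grad,
                                 order, h0_grad, false);
    }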