Skip to content

Commit aa83e19

Browse files
committed
Remove lstm_op including in gru_op
1 parent afd1f36 commit aa83e19

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

paddle/operators/gru_op.h

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#pragma once
1616

17-
#include "paddle/operators/lstm_op.h"
1817
#include "paddle/operators/math/gru_compute.h"
1918
#include "paddle/operators/math/math_function.h"
2019
#include "paddle/operators/math/sequence2batch.h"
@@ -25,6 +24,18 @@
2524
namespace paddle {
2625
namespace operators {
2726

27+
using LoDTensor = framework::LoDTensor;
28+
using Tensor = framework::Tensor;
29+
30+
template <typename Place, typename T>
31+
inline void ReorderInitState(const platform::DeviceContext& ctx,
32+
const framework::Tensor& src, const size_t* index,
33+
framework::Tensor* dst, bool indexed_src) {
34+
math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
35+
dst->mutable_data<T>(src.dims(), ctx.GetPlace());
36+
row_shuffle(ctx, src, index, *dst, indexed_src);
37+
}
38+
2839
template <typename Place, typename T>
2940
class GRUKernel : public framework::OpKernel<T> {
3041
public:
@@ -194,16 +205,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
194205
batch_reset_hidden_prev_grad.Slice(bstart, bend);
195206
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
196207
if (n == 0) {
197-
if (h0) {
198-
gru_value.prevOutValue = ordered_h0.data<T>();
199-
} else {
200-
gru_value.prevOutValue = nullptr;
201-
}
202-
if (h0 && h0_grad) {
203-
gru_grad.prevOutGrad = ordered_h0_grad.data<T>();
204-
} else {
205-
gru_grad.prevOutGrad = nullptr;
206-
}
208+
gru_value.prevOutValue = h0 ? ordered_h0.data<T>() : nullptr;
209+
gru_grad.prevOutGrad =
210+
h0 && h0_grad ? ordered_h0_grad.data<T>() : nullptr;
207211
} else {
208212
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
209213
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);

0 commit comments

Comments
 (0)