Commit 89d6d69

Merge pull request #12781 from tensor-tang/feature/op/fusion_gru
add fusion gru
2 parents d941192 + d9bf73f

File tree

4 files changed: +610 -103 lines changed
paddle/fluid/operators/fusion_gru_op.cc
Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fusion_gru_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/fluid/operators/math/detail/gru_kernel.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/gru_compute.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {

void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GRU should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
                 "Input(WeightX) of GRU should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                 "Input(WeightH) of GRU should not be null.");

  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Output(XX) of GRU should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
                 "Output(BatchedGate) of GRU should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
                 "Output(BatchResetHiddenPrev) of GRU should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
                 "Output(BatchedHidden) of GRU should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                 "Output(Hidden) of GRU should not be null.");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    "The rank of Input(WeightX) should be 2.");
  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
                    "The first dimension of Input(WeightX) "
                    "should be %d.",
                    x_dims[1]);

  int frame_size = wx_dims[1] / 3;
  auto wh_dims = ctx->GetInputDim("WeightH");
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    "The rank of Input(WeightH) should be 2.");
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    "The first dimension of Input(WeightH) "
                    "should be %d.",
                    frame_size);
  PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
                    "The second dimension of Input(WeightH) "
                    "should be 3 * %d.",
                    frame_size);

  if (ctx->HasInput("H0")) {
    auto h0_dims = ctx->GetInputDim("H0");
    PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
                      "The width of H0 must be equal to frame_size.");
  }
  if (ctx->HasInput("Bias")) {
    auto b_dims = ctx->GetInputDim("Bias");
    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
    PADDLE_ENFORCE_EQ(b_dims[0], 1,
                      "The first dimension of Input(Bias) should be 1.");
    PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
                      "The shape of Bias must be [1, frame_size * 3].");
  }
  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
  ctx->SetOutputDim("BatchedHidden", out_dims);
  ctx->SetOutputDim("BatchResetHiddenPrev", out_dims);
  ctx->ShareLoD("X", "Hidden");

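  // XX is a scratch output reused for whichever intermediate is narrower:
  // the FC result (T x 3D) when M > 3D, or the re-batched input (T x M).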
  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
      ctx.device_context());
}

void FusionGRUOpMaker::Make() {
  AddInput("X",
           "(LoDTensor) the input is a LoDTensor, which supports "
           "variable-length input sequences. The underlying tensor in "
           "this LoDTensor is a matrix with shape (T x M), where T is the "
           "total time steps in this mini-batch and M is the dim size of x.");
  AddInput("H0",
           "(Tensor, optional) The initial hidden state is an optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("WeightX",
           "(Tensor) The FC weight with shape (M x 3D), "
           "where M is the dim size of x and D is the hidden size.");
  AddInput("WeightH",
           "(Tensor) (D x 3D) Same as GRUOp, where D is the hidden size.");
  AddInput("Bias",
           "(Tensor, optional) (1 x 3D). "
           "Almost the same as GRUOp. "
           "Note: if the FC has a bias, it should be added into this bias.")
      .AsDispensable();
  AddOutput("XX",
            "(LoDTensor) the result after X * WeightX (size is T x 3D)"
            " or batched_X (size is T x M); this is chosen automatically,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size and M is the dim size of the x input.")
      .AsIntermediate();
  AddOutput("BatchedGate", "(LoDTensor) Same as GRUOp.").AsIntermediate();
  AddOutput("BatchResetHiddenPrev", "(LoDTensor) (T x 3D) Same as GRUOp.")
      .AsIntermediate();
  AddOutput("BatchedHidden", "(LoDTensor) (T x D) Same as GRUOp.")
      .AsIntermediate();
  AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp.");
  AddAttr<std::string>("activation",
                       "(string, default tanh) "
                       "The activation type used for the output candidate {h}_t.")
      .SetDefault("tanh");
  AddAttr<std::string>(
      "gate_activation",
      "(string, default sigmoid) "
      "The activation type used in the update gate and reset gate.")
      .SetDefault("sigmoid");
  AddAttr<bool>("is_reverse",
                "(bool, default: False) "
                "whether to compute the reversed GRU.")
      .SetDefault(false);
  AddComment(R"DOC(
The Fused GRU Operator.
This operator fuses the fully-connected (FC) operator into the GRU;
for more details, refer to the GRU op.
)DOC");
}

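// Copy the rows of src into dst in the order given by index_lod, so that an
// initial hidden state lines up with the sequence order produced by
// LoDTensor2BatchFunctor.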
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
                             const framework::Tensor& src,
                             framework::Vector<size_t> index_lod,
                             framework::Tensor* dst, bool indexed_src) {
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  row_shuffle(ctx, src, index_lod, dst, indexed_src);
}

template <typename DeviceContext, typename T>
class FusionGRUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<LoDTensor>("X");
    auto* wx = ctx.Input<Tensor>("WeightX");
    auto* wh = ctx.Input<Tensor>("WeightH");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* h0 = ctx.Input<Tensor>("H0");

    auto* xx = ctx.Output<LoDTensor>("XX");
    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
    auto* batch_reset_hidden_prev =
        ctx.Output<LoDTensor>("BatchResetHiddenPrev");
    auto* batch_hidden = ctx.Output<LoDTensor>("BatchedHidden");
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    bool is_reverse = ctx.Attr<bool>("is_reverse");

    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
    batch_reset_hidden_prev->mutable_data<T>(ctx.GetPlace());
    batch_hidden->mutable_data<T>(ctx.GetPlace());
    hidden_out->mutable_data<T>(ctx.GetPlace());

    const T* x_data = x->data<T>();
    const T* wx_data = wx->data<T>();
    const T* wh_data = wh->data<T>();
    auto x_dims = x->dims();
    auto wx_dims = wx->dims();
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
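    // Two equivalent paths for the fused FC: when the input width M exceeds
    // the gate width 3D, run the FC first (shrinking the data to T x 3D) and
    // then convert to batch layout; otherwise convert to batch layout first
    // and run the FC on the batched input.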
    if (x_dims[1] > wx_dims[1]) {
      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
                                        x_data, wx_data, xx_data,
                                        bias ? bias->data<T>() : NULL);
      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
      batched_gate->set_lod(xx->lod());
      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
                                        xx_data, wx_data, batched_gate_data,
                                        bias ? bias->data<T>() : NULL);
    }

    int frame_size = static_cast<int>(wx_dims[1] / 3);
    math::GRUMetaValue<T> gru_value;
    gru_value.gate_weight = const_cast<T*>(wh_data);
    gru_value.state_weight =
        const_cast<T*>(wh_data + 2 * frame_size * frame_size);
    Tensor ordered_h0;

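    // lod()[2] of the batched tensor records the order in which
    // LoDTensor2BatchFunctor sorted the sequences (longest first), so H0
    // must be shuffled the same way before it can seed the first step.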
    framework::Vector<size_t> order(batched_gate->lod()[2]);

    if (h0) {
      ReorderInitState<DeviceContext, T>(
          ctx.template device_context<DeviceContext>(), *h0, order,
          &ordered_h0, true);
      gru_value.prev_out_value = ordered_h0.data<T>();
    } else {
      gru_value.prev_out_value = nullptr;
    }
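    // lod()[0] of the batched tensor marks where each time step's batch of
    // still-active sequences begins, so seq_len is the longest sequence
    // length in this mini-batch.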
    auto batch_starts = batched_gate->lod()[0];
    size_t seq_len = batch_starts.size() - 1;
    auto active_node =
        math::detail::GetActivationType(ctx.Attr<std::string>("activation"));
    auto active_gate = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));

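    // With MKL and enough threads, WeightH is packed once up front so the
    // per-step recurrent GEMMs can reuse the packed buffers; otherwise the
    // generic GRUUnitFunctor path below is used instead.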
#ifdef PADDLE_WITH_MKLML
    // use MKL packed to speed up GEMM
    if (FLAGS_paddle_num_threads >= 4) {
      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
                                       frame_size * 2 /*width of weight*/,
                                       frame_size /*height of weight*/);
      PADDLE_ENFORCE(packed_gate);
      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
                     packed_gate);
      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
                                        frame_size /*width of weight*/,
                                        frame_size /*height of weight*/);
      PADDLE_ENFORCE(packed_state);
      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
                     frame_size, T(1.0), gru_value.state_weight, frame_size,
                     packed_state);
      for (size_t n = 0; n < seq_len; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;

        Tensor gate_t = batched_gate->Slice(bstart, bend);
        Tensor reset_hidden_prev_t =
            batch_reset_hidden_prev->Slice(bstart, bend);
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
        gru_value.output_value = hidden_t.data<T>();
        gru_value.gate_value = gate_t.data<T>();
        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

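        // On the first time step prev_out_value is null unless H0 was given,
        // so the recurrent GEMMs are skipped.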
        if (gru_value.prev_out_value) {
          blas.GEMM_COMPUTE(
              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
        }

        math::detail::forward_reset_output(
            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
            cur_batch_size, active_gate);

        if (gru_value.prev_out_value) {
          blas.GEMM_COMPUTE(
              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
              gru_value.reset_output_value, frame_size, packed_state,
              frame_size, T(1), gru_value.gate_value + frame_size * 2,
              frame_size * 3);
        }

        math::detail::forward_final_output(
            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
            cur_batch_size, active_node);

        gru_value.prev_out_value = gru_value.output_value;
      }

      blas.GEMM_FREE(packed_gate);
      blas.GEMM_FREE(packed_state);
    } else {
#endif
      for (size_t n = 0; n < seq_len; n++) {
        int bstart = static_cast<int>(batch_starts[n]);
        int bend = static_cast<int>(batch_starts[n + 1]);
        int cur_batch_size = bend - bstart;

        Tensor gate_t = batched_gate->Slice(bstart, bend);
        Tensor reset_hidden_prev_t =
            batch_reset_hidden_prev->Slice(bstart, bend);
        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
        gru_value.output_value = hidden_t.data<T>();
        gru_value.gate_value = gate_t.data<T>();
        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();

        math::GRUUnitFunctor<DeviceContext, T>::compute(
            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
            active_gate);

        gru_value.prev_out_value = gru_value.output_value;
      }
#ifdef PADDLE_WITH_MKLML
    }
#endif
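    // Convert the batch-major hidden states back to the original LoD
    // sequence layout expected by downstream ops.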
    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden->set_lod(batched_gate->lod());
    to_seq(dev_ctx, *batch_hidden, hidden_out);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_gru, ops::FusionGRUOp, ops::FusionGRUOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_CPU_KERNEL(
    fusion_gru, ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FusionGRUKernel<paddle::platform::CPUDeviceContext, double>);
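For readers who want to check the numerics, below is a minimal single-sequence reference sketch of what one fused time step computes, assuming Paddle's gate layout [update | reset | candidate] and the update rule h_t = (1 - u) * h_prev + u * tanh(candidate) implemented by gru_resetOutput / gru_finalOutput. The names (GruStepRef), toy sizes, and values are illustrative only and not part of the commit.

#include <cmath>
#include <cstdio>
#include <vector>

// One GRU step for a single sequence, hidden size D. `gates` holds the
// 3D-wide pre-activations x_t * WeightX + Bias, laid out as
// [update | reset | candidate], matching the kernel's gate buffer.
// wh_gate (D x 2D) and wh_state (D x D) mirror gate_weight / state_weight.
void GruStepRef(int D, const std::vector<float>& wh_gate,
                const std::vector<float>& wh_state,
                const std::vector<float>& h_prev, std::vector<float>* gates,
                std::vector<float>* h_out) {
  auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
  // Recurrent contribution to the update/reset gates:
  // gates[0:2D] += h_prev * wh_gate.
  for (int i = 0; i < D; ++i)
    for (int j = 0; j < 2 * D; ++j)
      (*gates)[j] += h_prev[i] * wh_gate[i * 2 * D + j];
  // reset_output = sigmoid(reset gate) * h_prev (BatchResetHiddenPrev).
  std::vector<float> reset_out(D);
  for (int i = 0; i < D; ++i)
    reset_out[i] = sigmoid((*gates)[D + i]) * h_prev[i];
  // Candidate pre-activation gains reset_out * wh_state.
  for (int i = 0; i < D; ++i)
    for (int j = 0; j < D; ++j)
      (*gates)[2 * D + j] += reset_out[i] * wh_state[i * D + j];
  // h_t = (1 - u) * h_prev + u * tanh(candidate).
  for (int i = 0; i < D; ++i) {
    float u = sigmoid((*gates)[i]);
    (*h_out)[i] =
        h_prev[i] - u * h_prev[i] + u * std::tanh((*gates)[2 * D + i]);
  }
}

int main() {
  const int D = 2, T = 3;
  // Toy weights and precomputed FC output (T x 3D); all values illustrative.
  std::vector<float> wh_gate(D * 2 * D, 0.1f), wh_state(D * D, 0.2f);
  std::vector<std::vector<float>> xw(T, std::vector<float>(3 * D, 0.5f));
  std::vector<float> h(D, 0.0f), h_next(D, 0.0f);
  for (int t = 0; t < T; ++t) {  // mirrors the kernel's seq_len loop
    GruStepRef(D, wh_gate, wh_state, h, &xw[t], &h_next);
    h = h_next;
  }
  std::printf("h_T = [%f, %f]\n", h[0], h[1]);
  return 0;
}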
paddle/fluid/operators/fusion_gru_op.h
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

class FusionGRUOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusionGRUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle
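To make the shape contract that this header's InferShape declaration enforces concrete, the sketch below prints the tensor dimensions wired up in the .cc above for some example sizes; T, M, D and all values here are assumptions for illustration only.

#include <cstdio>

int main() {
  // Example sizes: T total time steps, input width M, hidden size D.
  const int T = 12, M = 16, D = 8;
  // XX holds either the FC result (T x 3D) or the re-batched input (T x M),
  // so its width is min(M, 3D), exactly as InferShape computes it.
  const int xx_width = (M > 3 * D) ? 3 * D : M;
  std::printf("X:           %d x %d\n", T, M);
  std::printf("WeightX:     %d x %d\n", M, 3 * D);
  std::printf("WeightH:     %d x %d (D x 2D gate block + D x D state block)\n",
              D, 3 * D);
  std::printf("Bias:        1 x %d (FC bias folded in)\n", 3 * D);
  std::printf("XX:          %d x %d\n", T, xx_width);
  std::printf("BatchedGate: %d x %d, Hidden: %d x %d\n", T, 3 * D, T, D);
  return 0;
}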
