
Commit 18c322c

separate cpu and gpu implementations for gru kernel compute
1 parent 54c95e4 commit 18c322c

File tree

3 files changed: +225 -126 lines

  paddle/fluid/operators/gru_op.cc
  paddle/fluid/operators/gru_op.cu.cc
  paddle/fluid/operators/gru_op.h

paddle/fluid/operators/gru_op.cc

Lines changed: 135 additions & 3 deletions
@@ -211,16 +211,148 @@ class GRUGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+class GRUCPUKernel : public framework::OpKernel<T> {
+ public:
+  void BatchCompute(const framework::ExecutionContext& context) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* input = context.Input<LoDTensor>("Input");
+    auto* h0 = context.Input<Tensor>("H0");
+    auto* weight = context.Input<Tensor>("Weight");
+    const T* weight_data = weight->data<T>();
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(context.GetPlace());
+    auto* batch_reset_hidden_prev =
+        context.Output<LoDTensor>("BatchResetHiddenPrev");
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<LoDTensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    auto hidden_dims = hidden->dims();
+
+    bool is_reverse = context.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
+
+    if (bias) {
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Since the batch computing for GRU reorders the input sequences
+      // according to their length, the initial hidden state H0 also needs
+      // to be reordered.
+      ReorderInitState<DeviceContext, T>(
+          context.template device_context<DeviceContext>(), *h0, order,
+          &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
+
+#ifdef PADDLE_WITH_MKLML
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    // TODO(TJ): make a class
+    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                     frame_size * 2 /*width of weight*/,
+                                     frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_gate);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                   packed_gate);
+    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                      frame_size /*width of weight*/,
+                                      frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_state);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                   frame_size, T(1.0), gru_value.state_weight, frame_size,
+                   packed_state);
+#endif
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
+      int cur_batch_size = bend - bstart;
+
+      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
+      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+      gru_value.output_value = hidden_t.data<T>();
+      gru_value.gate_value = gate_t.data<T>();
+      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+#ifdef PADDLE_WITH_MKLML
+      if (gru_value.prev_out_value) {
+        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
+                          frame_size * 2, frame_size, gru_value.prev_out_value,
+                          frame_size, packed_gate, frame_size * 2, T(1),
+                          gru_value.gate_value, frame_size * 3);
+      }
+
+      math::detail::forward_reset_output(
+          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_gate);
+
+      if (gru_value.prev_out_value) {
+        blas.GEMM_COMPUTE(
+            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+            gru_value.reset_output_value, frame_size, packed_state, frame_size,
+            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
+      }
+
+      math::detail::forward_final_output(
+          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_node);
+#else
+      math::GRUUnitFunctor<DeviceContext, T>::compute(
+          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+          active_gate);
+#endif
+      gru_value.prev_out_value = gru_value.output_value;
+    }
+#ifdef PADDLE_WITH_MKLML
+    blas.GEMM_FREE(packed_gate);
+    blas.GEMM_FREE(packed_state);
+#endif
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden);
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    BatchCompute(context);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
-REGISTER_OP_CPU_KERNEL(
-    gru, ops::GRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
+                       ops::GRUCPUKernel<double>);
 REGISTER_OP_CPU_KERNEL(
     gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);
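Note on the CPU path above: with MKLML enabled, each time-step batch becomes two packed GEMMs (one into the two-gate block of gate_value, one into the candidate block at gate_value + frame_size * 2), followed by the elementwise stages from math::detail. For orientation only, here is a minimal, self-contained scalar sketch of that per-batch step with plain loops in place of the packed GEMMs. The gate ordering (update vs. reset halves), the sigmoid/tanh activations, and the final blend h = (1 - u) * h_prev + u * c follow one common GRU convention and are assumptions; GruStep is a hypothetical helper, not a Paddle API.

// gru_step_sketch.cc -- illustrative only, not part of this commit.
#include <cmath>
#include <vector>

static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// gates:        batch x (3 * frame_size), already holding the input projection
//               plus bias (what batch_gate contains after RowwiseAdd).
//               Row layout assumed: [gate block (2 * fs) | candidate (fs)].
// h_prev:       batch x frame_size   (gru_value.prev_out_value)
// gate_weight:  frame_size x (2 * frame_size)   (gru_value.gate_weight)
// state_weight: frame_size x frame_size         (gru_value.state_weight)
// h_out:        batch x frame_size, pre-sized   (gru_value.output_value)
void GruStep(std::vector<float>* gates, const std::vector<float>& h_prev,
             const std::vector<float>& gate_weight,
             const std::vector<float>& state_weight, std::vector<float>* h_out,
             int batch, int frame_size) {
  const int stride = 3 * frame_size;       // row stride of gate_value
  std::vector<float> reset_h(frame_size);  // one row of reset_output_value
  for (int b = 0; b < batch; ++b) {
    float* g = gates->data() + b * stride;
    const float* hp = h_prev.data() + b * frame_size;
    // (1) "gate" GEMM: g[0 .. 2*fs) += hp * gate_weight
    //     (the first packed GEMM_COMPUTE in the MKLML branch).
    for (int j = 0; j < 2 * frame_size; ++j)
      for (int k = 0; k < frame_size; ++k)
        g[j] += hp[k] * gate_weight[k * 2 * frame_size + j];
    // (2) forward_reset_output: activate both gates, form reset * h_prev.
    for (int j = 0; j < 2 * frame_size; ++j) g[j] = Sigmoid(g[j]);
    for (int j = 0; j < frame_size; ++j)
      reset_h[j] = g[frame_size + j] * hp[j];  // assumed: 2nd half = reset
    // (3) "state" GEMM: g[2*fs .. 3*fs) += reset_h * state_weight
    //     (the second packed GEMM_COMPUTE, written at gate_value + 2 * fs).
    for (int j = 0; j < frame_size; ++j)
      for (int k = 0; k < frame_size; ++k)
        g[2 * frame_size + j] += reset_h[k] * state_weight[k * frame_size + j];
    // (4) forward_final_output: candidate state, then blend with h_prev.
    for (int j = 0; j < frame_size; ++j) {
      float u = g[j];                              // assumed: 1st half = update
      float c = std::tanh(g[2 * frame_size + j]);  // candidate
      (*h_out)[b * frame_size + j] = (1.0f - u) * hp[j] + u * c;
    }
  }
}

Packing the two weight matrices once before the batch loop and reusing them via GEMM_COMPUTE is what the GEMM_ALLOC / GEMM_PACK / GEMM_FREE triple amortizes; only the activations (input side) change from step to step.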

paddle/fluid/operators/gru_op.cu.cc

Lines changed: 90 additions & 0 deletions
@@ -14,6 +14,96 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class GRUKernel : public framework::OpKernel<T> {
+ public:
+  void BatchCompute(const framework::ExecutionContext& context) const {
+    auto* input = context.Input<LoDTensor>("Input");
+    auto* h0 = context.Input<Tensor>("H0");
+    auto* weight = context.Input<Tensor>("Weight");
+    const T* weight_data = weight->data<T>();
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(context.GetPlace());
+    auto* batch_reset_hidden_prev =
+        context.Output<LoDTensor>("BatchResetHiddenPrev");
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<LoDTensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    auto hidden_dims = hidden->dims();
+
+    bool is_reverse = context.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
+
+    if (bias) {
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Since the batch computing for GRU reorders the input sequences
+      // according to their length, the initial hidden state H0 also needs
+      // to be reordered.
+      ReorderInitState<DeviceContext, T>(
+          context.template device_context<DeviceContext>(), *h0, order,
+          &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
+      int cur_batch_size = bend - bstart;
+
+      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
+      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+      gru_value.output_value = hidden_t.data<T>();
+      gru_value.gate_value = gate_t.data<T>();
+      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+      math::GRUUnitFunctor<DeviceContext, T>::compute(
+          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+          active_gate);
+      gru_value.prev_out_value = gru_value.output_value;
+    }
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden);
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    BatchCompute(context);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     gru, ops::GRUKernel<paddle::platform::CUDADeviceContext, float>,
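Both the CPU and CUDA kernels walk batch_gate in num_batch slices, where slice n holds time step n of every sequence that is still active; LoDTensor2BatchFunctor produces this layout by grouping sequences by length, and batch_starts (the lod()[0] level) records where each slice begins, so Slice(bstart, bend) covers exactly the rows advancing at that step. The following is a rough sketch of how such offsets relate to the per-sequence lengths, assuming sequences sorted by decreasing length; ComputeBatchStarts is a hypothetical helper, not the functor's real interface.

// batch_starts_sketch.cc -- illustrative only, not part of this commit.
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Given per-sequence lengths, return offsets analogous to
// batch_gate->lod()[0]: batch_starts[n] is where time step n of all
// still-active sequences begins; batch_starts.size() - 1 == num_batch.
std::vector<size_t> ComputeBatchStarts(std::vector<size_t> lengths) {
  std::sort(lengths.begin(), lengths.end(), std::greater<size_t>());
  size_t max_len = lengths.empty() ? 0 : lengths.front();
  std::vector<size_t> batch_starts(1, 0);
  for (size_t step = 0; step < max_len; ++step) {
    // Number of sequences that still have a token at this time step.
    size_t active = 0;
    while (active < lengths.size() && lengths[active] > step) ++active;
    batch_starts.push_back(batch_starts.back() + active);
  }
  return batch_starts;
}

// E.g. lengths {3, 1, 2} -> sorted {3, 2, 1} -> batch_starts {0, 3, 5, 6}:
// step 0 runs all three sequences, step 1 runs two of them, step 2 runs one.
// Chaining gru_value.prev_out_value = gru_value.output_value across slices is
// what carries the hidden state from step n to step n + 1.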

paddle/fluid/operators/gru_op.h

Lines changed: 0 additions & 123 deletions
@@ -40,129 +40,6 @@ inline void ReorderInitState(const DeviceContext& ctx,
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
 
-template <typename DeviceContext, typename T>
-class GRUKernel : public framework::OpKernel<T> {
- public:
-  void BatchCompute(const framework::ExecutionContext& context) const {
-    auto* input = context.Input<LoDTensor>("Input");
-    auto* h0 = context.Input<Tensor>("H0");
-    auto* weight = context.Input<Tensor>("Weight");
-    const T* weight_data = weight->data<T>();
-    auto* bias = context.Input<Tensor>("Bias");
-    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
-    batch_gate->mutable_data<T>(context.GetPlace());
-    auto* batch_reset_hidden_prev =
-        context.Output<LoDTensor>("BatchResetHiddenPrev");
-    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
-    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
-    batch_hidden->mutable_data<T>(context.GetPlace());
-    auto* hidden = context.Output<LoDTensor>("Hidden");
-    hidden->mutable_data<T>(context.GetPlace());
-
-    auto hidden_dims = hidden->dims();
-
-    bool is_reverse = context.Attr<bool>("is_reverse");
-    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
-
-    if (bias) {
-      math::RowwiseAdd<DeviceContext, T> add_bias;
-      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
-    }
-
-    int frame_size = hidden_dims[1];
-    math::GRUMetaValue<T> gru_value;
-    gru_value.gate_weight = const_cast<T*>(weight_data);
-    gru_value.state_weight =
-        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    Tensor ordered_h0;
-
-    framework::Vector<size_t> order(batch_gate->lod()[2]);
-
-    if (h0) {
-      // Since the batch computing for GRU reorders the input sequences
-      // according to their length. The initialized cell state also needs
-      // to reorder.
-      ReorderInitState<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), *h0, order,
-          &ordered_h0, true);
-      gru_value.prev_out_value = ordered_h0.data<T>();
-    } else {
-      gru_value.prev_out_value = nullptr;
-    }
-    auto batch_starts = batch_gate->lod()[0];
-    size_t num_batch = batch_starts.size() - 1;
-    auto active_node = math::detail::GetActivationType(
-        context.Attr<std::string>("activation"));
-    auto active_gate = math::detail::GetActivationType(
-        context.Attr<std::string>("gate_activation"));
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-
-    // TODO(TJ): make a class, make one pack
-    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                     frame_size * 2 /*width of weight*/,
-                                     frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_gate);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
-                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
-                   packed_gate);
-    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                      frame_size /*width of weight*/,
-                                      frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_state);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
-                   frame_size, T(1.0), gru_value.state_weight, frame_size,
-                   packed_state);
-
-    for (size_t n = 0; n < num_batch; n++) {
-      int bstart = static_cast<int>(batch_starts[n]);
-      int bend = static_cast<int>(batch_starts[n + 1]);
-      int cur_batch_size = bend - bstart;
-
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
-      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
-      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-      gru_value.output_value = hidden_t.data<T>();
-      gru_value.gate_value = gate_t.data<T>();
-      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
-                          frame_size * 2, frame_size, gru_value.prev_out_value,
-                          frame_size, packed_gate, frame_size * 2, T(1),
-                          gru_value.gate_value, frame_size * 3);
-      }
-
-      math::detail::forward_reset_output(
-          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_gate);
-
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(
-            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
-            gru_value.reset_output_value, frame_size, packed_state, frame_size,
-            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
-      }
-
-      math::detail::forward_final_output(
-          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_node);
-
-      gru_value.prev_out_value = gru_value.output_value;
-    }
-    blas.GEMM_FREE(packed_gate);
-    blas.GEMM_FREE(packed_state);
-
-    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
-    batch_hidden->set_lod(batch_gate->lod());
-    to_seq(dev_ctx, *batch_hidden, hidden);
-  }
-
-  void Compute(const framework::ExecutionContext& context) const override {
-    BatchCompute(context);
-  }
-};
-
 template <typename DeviceContext, typename T>
 class GRUGradKernel : public framework::OpKernel<T> {
  public:
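gru_op.h still provides ReorderInitState, which both kernels call to permute the H0 rows into the order produced by the sequence-to-batch pass (batch_gate->lod()[2]) before the first time step; the real helper forwards an indexed_src flag to row_shuffle to select the direction of the permutation. Conceptually the forward-pass case is a row gather; below is a minimal sketch under that assumption, with GatherRows as a hypothetical stand-in, not the actual functor.

// gather_rows_sketch.cc -- illustrative only, not part of this commit.
#include <algorithm>
#include <cstddef>
#include <vector>

// One plausible direction of the permutation (indexed_src == true):
// dst row i = src row order[i], each row frame_size floats wide.
void GatherRows(const std::vector<float>& src,
                const std::vector<size_t>& order, int frame_size,
                std::vector<float>* dst) {
  dst->resize(order.size() * frame_size);
  for (size_t i = 0; i < order.size(); ++i) {
    const float* src_row = src.data() + order[i] * frame_size;
    std::copy(src_row, src_row + frame_size, dst->data() + i * frame_size);
  }
}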
