Skip to content

Commit 9cb8738

Browse files
authored
Merge pull request #14018 from tensor-tang/refine/jit/gru
Refine/jit/gru
2 parents 8c1eea9 + 032c3a0 commit 9cb8738

File tree

5 files changed

+244
-154
lines changed

5 files changed

+244
-154
lines changed

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 50 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@ limitations under the License. */
1616
#include <cstring> // for memcpy
1717
#include <string>
1818
#include "paddle/fluid/operators/math/blas.h"
19-
#include "paddle/fluid/operators/math/cpu_vec.h"
2019
#include "paddle/fluid/operators/math/fc_compute.h"
20+
#include "paddle/fluid/operators/math/jit_kernel.h"
2121
#include "paddle/fluid/operators/math/sequence2batch.h"
22-
#include "paddle/fluid/platform/cpu_info.h"
2322

2423
namespace paddle {
2524
namespace operators {
@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
174173
}
175174
}
176175

177-
#define INIT_VEC_FUNC \
178-
std::function<void(const int, const T *, T *)> act_gate, act_state; \
179-
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \
180-
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
181-
auto& act_state_str = ctx.Attr<std::string>("activation"); \
182-
if (platform::jit::MayIUse(platform::jit::avx)) { \
183-
math::VecActivations<T, platform::jit::avx> act_functor; \
184-
act_gate = act_functor(act_gate_str); \
185-
act_state = act_functor(act_state_str); \
186-
cross = math::vec_cross<T, platform::jit::avx>; \
187-
} else { \
188-
math::VecActivations<T, platform::jit::isa_any> act_functor; \
189-
act_gate = act_functor(act_gate_str); \
190-
act_state = act_functor(act_state_str); \
191-
cross = math::vec_cross<T, platform::jit::isa_any>; \
192-
}
193-
194-
#define INIT_BASE_INPUT_OUTPUT \
195-
auto* h0 = ctx.Input<Tensor>("H0"); \
196-
auto* wx = ctx.Input<Tensor>("WeightX"); \
197-
auto* wh = ctx.Input<Tensor>("WeightH"); \
198-
auto* bias = ctx.Input<Tensor>("Bias"); \
199-
auto* xx = ctx.Output<LoDTensor>("XX"); \
200-
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
201-
bool is_reverse = ctx.Attr<bool>("is_reverse");
202-
203-
#define INIT_BASE_SIZES \
204-
auto x_dims = x->dims(); /* T x M*/ \
205-
auto wh_dims = wh->dims(); /* D x 3D*/ \
206-
const int total_T = x_dims[0]; \
207-
const int M = x_dims[1]; \
208-
const int D = wh_dims[0]; \
209-
const int D3 = wh_dims[1]; \
210-
const int D2 = D * 2;
176+
#define INIT_BASE_DEFINES \
177+
auto* x = ctx.Input<LoDTensor>("X"); \
178+
auto* wh = ctx.Input<Tensor>("WeightH"); \
179+
auto* xx = ctx.Output<LoDTensor>("XX"); \
180+
auto x_lod = x->lod(); \
181+
auto x_dims = x->dims(); /* T x M*/ \
182+
auto wh_dims = wh->dims(); /* D x 3D*/ \
183+
const int total_T = x_dims[0]; \
184+
const int D3 = wh_dims[1]
185+
186+
#define INIT_OTHER_DEFINES \
187+
auto* h0 = ctx.Input<Tensor>("H0"); \
188+
auto* wx = ctx.Input<Tensor>("WeightX"); \
189+
auto* bias = ctx.Input<Tensor>("Bias"); \
190+
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
191+
bool is_reverse = ctx.Attr<bool>("is_reverse"); \
192+
const int M = x_dims[1]; \
193+
const int D = wh_dims[0]; \
194+
const int D2 = D * 2; \
195+
const auto& ker = math::jitkernel::KernelPool::Instance() \
196+
.template Get<math::jitkernel::GRUKernel<T>, \
197+
const std::string&, const std::string&>( \
198+
ctx.Attr<std::string>("gate_activation"), \
199+
ctx.Attr<std::string>("activation"), D); \
200+
const T* x_data = x->data<T>(); \
201+
const T* wx_data = wx->data<T>(); \
202+
const T* wh_data = wh->data<T>(); \
203+
auto place = ctx.GetPlace(); \
204+
T* xx_data = xx->mutable_data<T>(place)
211205

212206
void SeqCompute(const framework::ExecutionContext& ctx) const {
213207
using DeviceContext = paddle::platform::CPUDeviceContext;
214-
auto* x = ctx.Input<LoDTensor>("X");
215-
INIT_BASE_INPUT_OUTPUT
216-
INIT_BASE_SIZES
217-
INIT_VEC_FUNC
218-
219-
auto x_lod = x->lod();
208+
INIT_BASE_DEFINES;
209+
INIT_OTHER_DEFINES;
220210
const int N = x_lod[0].size() - 1;
221-
const T* x_data = x->data<T>();
222211
const T* h0_data = h0 ? h0->data<T>() : nullptr;
223-
const T* wx_data = wx->data<T>();
224-
const T* wh_data = wh->data<T>();
225212
const T* wh_state_data = wh_data + D * D2;
226-
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
227-
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
228-
213+
T* hidden_out_data = hidden_out->mutable_data<T>(place);
229214
auto blas = math::GetBlas<DeviceContext, T>(ctx);
230215
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
231216
xx_data,
@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
252237
if (h0_data) {
253238
prev_hidden_data = h0_data + bid * D;
254239
} else {
255-
// W: {W_update, W_reset; W_state}
256-
// update gate
257-
act_gate(D, xx_data, xx_data);
258-
// state gate
259-
act_state(D, xx_data + D2, xx_data + D2);
260-
// out = a*b
261-
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
262-
// save prev
240+
ker->ComputeH1(xx_data, hidden_out_data);
263241
prev_hidden_data = hidden_out_data;
264242
tstart = 1;
265243
move_step();
@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
269247
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
270248
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
271249
D3);
272-
act_gate(D2, xx_data, xx_data);
273-
// rt = rt*ht_1 inplace result
274-
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
275-
250+
ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
276251
// gemm rt * Ws
277252
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
278253
hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
279254
xx_data + D2, D3);
280-
act_state(D, xx_data + D2, xx_data + D2);
281-
// out = zt*ht~ + (1-zt)*ht_1
282-
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
255+
ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
283256
// save prev
284257
prev_hidden_data = hidden_out_data;
285258
move_step();
@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
289262

290263
void BatchCompute(const framework::ExecutionContext& ctx) const {
291264
using DeviceContext = paddle::platform::CPUDeviceContext;
292-
auto* x = ctx.Input<LoDTensor>("X");
293-
INIT_BASE_INPUT_OUTPUT
294-
INIT_BASE_SIZES
295-
if (x->lod()[0].size() == 2) {
265+
INIT_BASE_DEFINES;
266+
if (x_lod[0].size() == 2) {
296267
xx->Resize({total_T, D3});
297268
SeqCompute(ctx);
298269
return;
299270
}
300-
INIT_VEC_FUNC
301-
271+
INIT_OTHER_DEFINES;
302272
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
303273
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
304274
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
305-
306-
const T* x_data = x->data<T>();
307-
const T* wx_data = wx->data<T>();
308-
const T* wh_data = wh->data<T>();
309-
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
310-
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
311-
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
312-
hidden_out->mutable_data<T>(ctx.GetPlace());
313-
275+
T* batched_input_data = batched_input->mutable_data<T>(place);
276+
T* batched_out_data = batched_out->mutable_data<T>(place);
277+
hidden_out->mutable_data<T>(place);
314278
auto& dev_ctx = ctx.template device_context<DeviceContext>();
315279
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
316280
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
336300
T* prev_hidden_data = nullptr;
337301
if (h0) {
338302
// reorder h0
339-
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
303+
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
340304
const T* h0_data = h0->data<T>();
341305
prev_hidden_data = reordered_h0_data;
342306
size_t sz = sizeof(T) * D;
@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
350314
T* cur_out_data = batched_out_data;
351315
// W: {W_update, W_reset; W_state}
352316
for (int i = 0; i < max_bs; ++i) {
353-
// update gate
354-
act_gate(D, cur_in_data, cur_in_data);
355-
// state gate
356-
act_state(D, cur_in_data + D2, cur_in_data + D2);
357-
// out = a*b
358-
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
317+
ker->ComputeH1(cur_in_data, cur_out_data);
359318
// add offset
360319
cur_in_data += D3;
361320
cur_out_data += D;
@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
380339
T* cur_out_data = batched_out_data;
381340
T* cur_prev_hidden_data = prev_hidden_data;
382341
for (int i = 0; i < cur_bs; ++i) {
383-
act_gate(D2, cur_batched_data, cur_batched_data);
384-
// rt = rt*ht_1 inplace result
385-
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
386-
342+
ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
343+
cur_out_data);
387344
cur_batched_data += D3;
388345
cur_prev_hidden_data += D;
389346
cur_out_data += D;
@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
397354

398355
cur_prev_hidden_data = prev_hidden_data;
399356
for (int i = 0; i < cur_bs; ++i) {
400-
// ht~ = act_state(...)
401-
act_state(D, cur_batched_data + D2, cur_batched_data + D2);
402-
// out = zt*ht~ + (1-zt)*ht_1
403-
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
404-
cur_out_data);
405-
357+
ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
358+
cur_out_data);
406359
cur_batched_data += D3;
407360
cur_prev_hidden_data += D;
408361
cur_out_data += D;
@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
416369
batched_out->set_lod(batched_lod);
417370
to_seq(dev_ctx, *batched_out, hidden_out);
418371
}
419-
#undef INIT_VEC_FUNC
420-
#undef INIT_BASE_SIZES
421-
#undef INIT_BASE_INPUT_OUTPUT
372+
#undef INIT_OTHER_DEFINES
373+
#undef INIT_BASE_DEFINES
422374
};
423375

424376
} // namespace operators

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ endif()
7575
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
7676
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
7777
cc_library(jit_kernel
78-
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
78+
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
7979
DEPS cpu_info cblas)
8080
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
142142
const T *wp_data = nullptr) const = 0;
143143
};
144144

145+
template <typename T>
146+
class GRUKernel : public Kernel {
147+
public:
148+
// compute h1 without h0
149+
virtual void ComputeH1(T *gates, T *ht) const = 0;
150+
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
151+
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
152+
};
153+
145154
} // namespace jitkernel
146155
} // namespace math
147156
} // namespace operators

0 commit comments

Comments
 (0)