@@ -16,10 +16,9 @@ limitations under the License. */
16
16
#include < cstring> // for memcpy
17
17
#include < string>
18
18
#include " paddle/fluid/operators/math/blas.h"
19
- #include " paddle/fluid/operators/math/cpu_vec.h"
20
19
#include " paddle/fluid/operators/math/fc_compute.h"
20
+ #include " paddle/fluid/operators/math/jit_kernel.h"
21
21
#include " paddle/fluid/operators/math/sequence2batch.h"
22
- #include " paddle/fluid/platform/cpu_info.h"
23
22
24
23
namespace paddle {
25
24
namespace operators {
@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
174
173
}
175
174
}
176
175
177
- #define INIT_VEC_FUNC \
178
- std::function<void (const int , const T *, T *)> act_gate, act_state; \
179
- std::function<void (const int , const T*, const T*, const T*, T*)> cross; \
180
- auto & act_gate_str = ctx.Attr<std::string>(" gate_activation" ); \
181
- auto & act_state_str = ctx.Attr<std::string>(" activation" ); \
182
- if (platform::jit::MayIUse(platform::jit::avx)) { \
183
- math::VecActivations<T, platform::jit::avx> act_functor; \
184
- act_gate = act_functor (act_gate_str); \
185
- act_state = act_functor (act_state_str); \
186
- cross = math::vec_cross<T, platform::jit::avx>; \
187
- } else { \
188
- math::VecActivations<T, platform::jit::isa_any> act_functor; \
189
- act_gate = act_functor (act_gate_str); \
190
- act_state = act_functor (act_state_str); \
191
- cross = math::vec_cross<T, platform::jit::isa_any>; \
192
- }
193
-
194
- #define INIT_BASE_INPUT_OUTPUT \
195
- auto * h0 = ctx.Input<Tensor>(" H0" ); \
196
- auto * wx = ctx.Input<Tensor>(" WeightX" ); \
197
- auto * wh = ctx.Input<Tensor>(" WeightH" ); \
198
- auto * bias = ctx.Input<Tensor>(" Bias" ); \
199
- auto * xx = ctx.Output<LoDTensor>(" XX" ); \
200
- auto * hidden_out = ctx.Output<LoDTensor>(" Hidden" ); \
201
- bool is_reverse = ctx.Attr<bool >(" is_reverse" );
202
-
203
- #define INIT_BASE_SIZES \
204
- auto x_dims = x->dims (); /* T x M*/ \
205
- auto wh_dims = wh->dims (); /* D x 3D*/ \
206
- const int total_T = x_dims[0 ]; \
207
- const int M = x_dims[1 ]; \
208
- const int D = wh_dims[0 ]; \
209
- const int D3 = wh_dims[1 ]; \
210
- const int D2 = D * 2 ;
176
+ #define INIT_BASE_DEFINES \
177
+ auto * x = ctx.Input<LoDTensor>(" X" ); \
178
+ auto * wh = ctx.Input<Tensor>(" WeightH" ); \
179
+ auto * xx = ctx.Output<LoDTensor>(" XX" ); \
180
+ auto x_lod = x->lod (); \
181
+ auto x_dims = x->dims (); /* T x M*/ \
182
+ auto wh_dims = wh->dims (); /* D x 3D*/ \
183
+ const int total_T = x_dims[0 ]; \
184
+ const int D3 = wh_dims[1 ]
185
+
186
+ #define INIT_OTHER_DEFINES \
187
+ auto * h0 = ctx.Input<Tensor>(" H0" ); \
188
+ auto * wx = ctx.Input<Tensor>(" WeightX" ); \
189
+ auto * bias = ctx.Input<Tensor>(" Bias" ); \
190
+ auto * hidden_out = ctx.Output<LoDTensor>(" Hidden" ); \
191
+ bool is_reverse = ctx.Attr<bool >(" is_reverse" ); \
192
+ const int M = x_dims[1 ]; \
193
+ const int D = wh_dims[0 ]; \
194
+ const int D2 = D * 2 ; \
195
+ const auto & ker = math::jitkernel::KernelPool::Instance() \
196
+ .template Get<math::jitkernel::GRUKernel<T>, \
197
+ const std::string&, const std::string&>( \
198
+ ctx.Attr<std::string>(" gate_activation" ), \
199
+ ctx.Attr<std::string>(" activation" ), D); \
200
+ const T* x_data = x->data<T>(); \
201
+ const T* wx_data = wx->data<T>(); \
202
+ const T* wh_data = wh->data<T>(); \
203
+ auto place = ctx.GetPlace(); \
204
+ T* xx_data = xx->mutable_data<T>(place)
211
205
212
206
void SeqCompute (const framework::ExecutionContext& ctx) const {
213
207
using DeviceContext = paddle::platform::CPUDeviceContext;
214
- auto * x = ctx.Input <LoDTensor>(" X" );
215
- INIT_BASE_INPUT_OUTPUT
216
- INIT_BASE_SIZES
217
- INIT_VEC_FUNC
218
-
219
- auto x_lod = x->lod ();
208
+ INIT_BASE_DEFINES;
209
+ INIT_OTHER_DEFINES;
220
210
const int N = x_lod[0 ].size () - 1 ;
221
- const T* x_data = x->data <T>();
222
211
const T* h0_data = h0 ? h0->data <T>() : nullptr ;
223
- const T* wx_data = wx->data <T>();
224
- const T* wh_data = wh->data <T>();
225
212
const T* wh_state_data = wh_data + D * D2;
226
- T* xx_data = xx->mutable_data <T>(ctx.GetPlace ());
227
- T* hidden_out_data = hidden_out->mutable_data <T>(ctx.GetPlace ());
228
-
213
+ T* hidden_out_data = hidden_out->mutable_data <T>(place);
229
214
auto blas = math::GetBlas<DeviceContext, T>(ctx);
230
215
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
231
216
xx_data,
@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
252
237
if (h0_data) {
253
238
prev_hidden_data = h0_data + bid * D;
254
239
} else {
255
- // W: {W_update, W_reset; W_state}
256
- // update gate
257
- act_gate (D, xx_data, xx_data);
258
- // state gate
259
- act_state (D, xx_data + D2, xx_data + D2);
260
- // out = a*b
261
- blas.VMUL (D, xx_data, xx_data + D2, hidden_out_data);
262
- // save prev
240
+ ker->ComputeH1 (xx_data, hidden_out_data);
263
241
prev_hidden_data = hidden_out_data;
264
242
tstart = 1 ;
265
243
move_step ();
@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
269
247
blas.GEMM (CblasNoTrans, CblasNoTrans, 1 , D2, D, static_cast <T>(1 ),
270
248
prev_hidden_data, D, wh_data, D2, static_cast <T>(1 ), xx_data,
271
249
D3);
272
- act_gate (D2, xx_data, xx_data);
273
- // rt = rt*ht_1 inplace result
274
- blas.VMUL (D, prev_hidden_data, xx_data + D, hidden_out_data);
275
-
250
+ ker->ComputeHtPart1 (xx_data, prev_hidden_data, hidden_out_data);
276
251
// gemm rt * Ws
277
252
blas.GEMM (CblasNoTrans, CblasNoTrans, 1 , D, D, static_cast <T>(1 ),
278
253
hidden_out_data, D, wh_state_data, D, static_cast <T>(1 ),
279
254
xx_data + D2, D3);
280
- act_state (D, xx_data + D2, xx_data + D2);
281
- // out = zt*ht~ + (1-zt)*ht_1
282
- cross (D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
255
+ ker->ComputeHtPart2 (xx_data, prev_hidden_data, hidden_out_data);
283
256
// save prev
284
257
prev_hidden_data = hidden_out_data;
285
258
move_step ();
@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
289
262
290
263
void BatchCompute (const framework::ExecutionContext& ctx) const {
291
264
using DeviceContext = paddle::platform::CPUDeviceContext;
292
- auto * x = ctx.Input <LoDTensor>(" X" );
293
- INIT_BASE_INPUT_OUTPUT
294
- INIT_BASE_SIZES
295
- if (x->lod ()[0 ].size () == 2 ) {
265
+ INIT_BASE_DEFINES;
266
+ if (x_lod[0 ].size () == 2 ) {
296
267
xx->Resize ({total_T, D3});
297
268
SeqCompute (ctx);
298
269
return ;
299
270
}
300
- INIT_VEC_FUNC
301
-
271
+ INIT_OTHER_DEFINES;
302
272
auto * reordered_h0 = ctx.Output <Tensor>(" ReorderedH0" );
303
273
auto * batched_input = ctx.Output <LoDTensor>(" BatchedInput" );
304
274
auto * batched_out = ctx.Output <LoDTensor>(" BatchedOut" );
305
-
306
- const T* x_data = x->data <T>();
307
- const T* wx_data = wx->data <T>();
308
- const T* wh_data = wh->data <T>();
309
- T* xx_data = xx->mutable_data <T>(ctx.GetPlace ());
310
- T* batched_input_data = batched_input->mutable_data <T>(ctx.GetPlace ());
311
- T* batched_out_data = batched_out->mutable_data <T>(ctx.GetPlace ());
312
- hidden_out->mutable_data <T>(ctx.GetPlace ());
313
-
275
+ T* batched_input_data = batched_input->mutable_data <T>(place);
276
+ T* batched_out_data = batched_out->mutable_data <T>(place);
277
+ hidden_out->mutable_data <T>(place);
314
278
auto & dev_ctx = ctx.template device_context <DeviceContext>();
315
279
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
316
280
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
336
300
T* prev_hidden_data = nullptr ;
337
301
if (h0) {
338
302
// reorder h0
339
- T* reordered_h0_data = reordered_h0->mutable_data <T>(ctx. GetPlace () );
303
+ T* reordered_h0_data = reordered_h0->mutable_data <T>(place );
340
304
const T* h0_data = h0->data <T>();
341
305
prev_hidden_data = reordered_h0_data;
342
306
size_t sz = sizeof (T) * D;
@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
350
314
T* cur_out_data = batched_out_data;
351
315
// W: {W_update, W_reset; W_state}
352
316
for (int i = 0 ; i < max_bs; ++i) {
353
- // update gate
354
- act_gate (D, cur_in_data, cur_in_data);
355
- // state gate
356
- act_state (D, cur_in_data + D2, cur_in_data + D2);
357
- // out = a*b
358
- blas.VMUL (D, cur_in_data, cur_in_data + D2, cur_out_data);
317
+ ker->ComputeH1 (cur_in_data, cur_out_data);
359
318
// add offset
360
319
cur_in_data += D3;
361
320
cur_out_data += D;
@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
380
339
T* cur_out_data = batched_out_data;
381
340
T* cur_prev_hidden_data = prev_hidden_data;
382
341
for (int i = 0 ; i < cur_bs; ++i) {
383
- act_gate (D2, cur_batched_data, cur_batched_data);
384
- // rt = rt*ht_1 inplace result
385
- blas.VMUL (D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
386
-
342
+ ker->ComputeHtPart1 (cur_batched_data, cur_prev_hidden_data,
343
+ cur_out_data);
387
344
cur_batched_data += D3;
388
345
cur_prev_hidden_data += D;
389
346
cur_out_data += D;
@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
397
354
398
355
cur_prev_hidden_data = prev_hidden_data;
399
356
for (int i = 0 ; i < cur_bs; ++i) {
400
- // ht~ = act_state(...)
401
- act_state (D, cur_batched_data + D2, cur_batched_data + D2);
402
- // out = zt*ht~ + (1-zt)*ht_1
403
- cross (D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
404
- cur_out_data);
405
-
357
+ ker->ComputeHtPart2 (cur_batched_data, cur_prev_hidden_data,
358
+ cur_out_data);
406
359
cur_batched_data += D3;
407
360
cur_prev_hidden_data += D;
408
361
cur_out_data += D;
@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
416
369
batched_out->set_lod (batched_lod);
417
370
to_seq (dev_ctx, *batched_out, hidden_out);
418
371
}
419
- #undef INIT_VEC_FUNC
420
- #undef INIT_BASE_SIZES
421
- #undef INIT_BASE_INPUT_OUTPUT
372
+ #undef INIT_OTHER_DEFINES
373
+ #undef INIT_BASE_DEFINES
422
374
};
423
375
424
376
} // namespace operators
0 commit comments