@@ -14,6 +14,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/gru_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/gru_cpu_kernel.h"
+#include "paddle/fluid/operators/math/detail/gru_kernel.h"
+
+DECLARE_int32(paddle_num_threads);
 
 namespace paddle {
 namespace operators {
@@ -264,76 +269,94 @@ class GRUCPUKernel : public framework::OpKernel<T> {
       gru_value.prev_out_value = nullptr;
     }
     auto batch_starts = batch_gate->lod()[0];
-    size_t num_batch = batch_starts.size() - 1;
+    size_t seq_len = batch_starts.size() - 1;
     auto active_node = math::detail::GetActivationType(
         context.Attr<std::string>("activation"));
     auto active_gate = math::detail::GetActivationType(
         context.Attr<std::string>("gate_activation"));
 
 #ifdef PADDLE_WITH_MKLML
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-    // TODO(TJ): make a class
-    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                     frame_size * 2 /*width of weight*/,
-                                     frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_gate);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
-                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
-                   packed_gate);
-    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
-                                      frame_size /*width of weight*/,
-                                      frame_size /*height of height*/);
-    PADDLE_ENFORCE(packed_state);
-    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
-                   frame_size, T(1.0), gru_value.state_weight, frame_size,
-                   packed_state);
-#endif
-    for (size_t n = 0; n < num_batch; n++) {
-      int bstart = static_cast<int>(batch_starts[n]);
-      int bend = static_cast<int>(batch_starts[n + 1]);
-      int cur_batch_size = bend - bstart;
-
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
-      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
-      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-      gru_value.output_value = hidden_t.data<T>();
-      gru_value.gate_value = gate_t.data<T>();
-      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+    if (FLAGS_paddle_num_threads >= 4) {
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+      T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                       frame_size * 2 /*width of weight*/,
+                                       frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_gate);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                     frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                     packed_gate);
+      T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                        frame_size /*width of weight*/,
+                                        frame_size /*height of height*/);
+      PADDLE_ENFORCE(packed_state);
+      blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                     frame_size, T(1.0), gru_value.state_weight, frame_size,
+                     packed_state);
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
 
-#ifdef PADDLE_WITH_MKLML
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
-                          frame_size * 2, frame_size, gru_value.prev_out_value,
-                          frame_size, packed_gate, frame_size * 2, T(1),
-                          gru_value.gate_value, frame_size * 3);
-      }
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
 
-      math::detail::forward_reset_output(
-          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_gate);
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size * 2,
+              frame_size, gru_value.prev_out_value, frame_size, packed_gate,
+              frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
+        }
 
-      if (gru_value.prev_out_value) {
-        blas.GEMM_COMPUTE(
-            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
-            gru_value.reset_output_value, frame_size, packed_state, frame_size,
-            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
+        math::detail::forward_reset_output(
+            math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_gate);
+
+        if (gru_value.prev_out_value) {
+          blas.GEMM_COMPUTE(
+              CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+              gru_value.reset_output_value, frame_size, packed_state,
+              frame_size, T(1), gru_value.gate_value + frame_size * 2,
+              frame_size * 3);
+        }
+
+        math::detail::forward_final_output(
+            math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+            cur_batch_size, active_node);
+
+        gru_value.prev_out_value = gru_value.output_value;
       }
 
-      math::detail::forward_final_output(
-          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
-          cur_batch_size, active_node);
-#else
-      math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
-          active_gate);
+      blas.GEMM_FREE(packed_gate);
+      blas.GEMM_FREE(packed_state);
+    } else {
 #endif
-      gru_value.prev_out_value = gru_value.output_value;
-    }
+      for (size_t n = 0; n < seq_len; n++) {
+        int bstart = static_cast<int>(batch_starts[n]);
+        int bend = static_cast<int>(batch_starts[n + 1]);
+        int cur_batch_size = bend - bstart;
+
+        Tensor gate_t = batch_gate->Slice(bstart, bend);
+        Tensor reset_hidden_prev_t =
+            batch_reset_hidden_prev->Slice(bstart, bend);
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        gru_value.output_value = hidden_t.data<T>();
+        gru_value.gate_value = gate_t.data<T>();
+        gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+        math::GRUUnitFunctor<DeviceContext, T>::compute(
+            dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+            active_gate);
+
+        gru_value.prev_out_value = gru_value.output_value;
+      }
 #ifdef PADDLE_WITH_MKLML
-    blas.GEMM_FREE(packed_gate);
-    blas.GEMM_FREE(packed_state);
+    }
 #endif
-
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
     batch_hidden->set_lod(batch_gate->lod());
     to_seq(dev_ctx, *batch_hidden, hidden);
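
Note on the dispatch flag: the first hunk only declares the gflag with DECLARE_int32(paddle_num_threads); the matching DEFINE_int32 lives elsewhere in the tree and is not part of this diff. Below is a minimal, self-contained sketch of the same declare/define pattern and the ">= 4" dispatch that the kernel hunk introduces; the default value, the file name, and the two stand-in functions are illustrative assumptions, not Paddle code.

// gru_dispatch_sketch.cc -- illustrative only; mirrors the shape of the
// dispatch added in the diff above, not the real kernel.
#include <cstdio>
#include <gflags/gflags.h>

// Assumed default for the sketch; in Paddle the flag is DEFINEd centrally
// and only DECLAREd in the operator file, as the first hunk shows.
DEFINE_int32(paddle_num_threads, 1, "number of CPU threads for math kernels");

static void PackedGemmPath() { std::puts("MKL packed-GEMM GRU path"); }
static void GenericPath() { std::puts("GRUUnitFunctor GRU path"); }

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  // Same rule as the kernel: take the packed-GEMM path only when at least
  // four threads are configured; otherwise fall back to the generic path.
  if (FLAGS_paddle_num_threads >= 4) {
    PackedGemmPath();
  } else {
    GenericPath();
  }
  return 0;
}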