@@ -211,16 +211,148 @@ class GRUGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+class GRUCPUKernel : public framework::OpKernel<T> {
+ public:
+  void BatchCompute(const framework::ExecutionContext& context) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* input = context.Input<LoDTensor>("Input");
+    auto* h0 = context.Input<Tensor>("H0");
+    auto* weight = context.Input<Tensor>("Weight");
+    const T* weight_data = weight->data<T>();
+    auto* bias = context.Input<Tensor>("Bias");
+    auto* batch_gate = context.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(context.GetPlace());
+    auto* batch_reset_hidden_prev =
+        context.Output<LoDTensor>("BatchResetHiddenPrev");
+    batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
+    auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(context.GetPlace());
+    auto* hidden = context.Output<LoDTensor>("Hidden");
+    hidden->mutable_data<T>(context.GetPlace());
+
+    auto hidden_dims = hidden->dims();
+
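+    // Convert the LoD (variable-length sequence) input into batch layout,
+    // reordering sequences so each time step can be processed as one batch.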
+    bool is_reverse = context.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    to_batch(dev_ctx, *input, batch_gate, true, is_reverse);
+
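+    // If a bias is provided, add it row-wise to every batched gate value.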
+    if (bias) {
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+      add_bias(dev_ctx, *batch_gate, *bias, batch_gate);
+    }
+
+    int frame_size = hidden_dims[1];
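+    // The weight tensor stores the update/reset gate weights first
+    // (2 * frame_size * frame_size elements), followed by the candidate
+    // (state) weights.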
+    math::GRUMetaValue<T> gru_value;
+    gru_value.gate_weight = const_cast<T*>(weight_data);
+    gru_value.state_weight =
+        const_cast<T*>(weight_data + 2 * frame_size * frame_size);
+    Tensor ordered_h0;
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
+    if (h0) {
+      // Since the batch computing for GRU reorders the input sequences
+      // according to their length, the initial hidden state (H0) also
+      // needs to be reordered.
+      ReorderInitState<DeviceContext, T>(
+          context.template device_context<DeviceContext>(), *h0, order,
+          &ordered_h0, true);
+      gru_value.prev_out_value = ordered_h0.data<T>();
+    } else {
+      gru_value.prev_out_value = nullptr;
+    }
+    auto batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
+
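+    // With MKLML available, pack the gate and state weight matrices once up
+    // front so the per-step GEMMs in the loop below can reuse the packed
+    // buffers.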
+#ifdef PADDLE_WITH_MKLML
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    // TODO(TJ): make a class
+    T* packed_gate = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                     frame_size * 2 /*width of weight*/,
+                                     frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_gate);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size * 2,
+                   frame_size, T(1.0), gru_value.gate_weight, frame_size * 2,
+                   packed_gate);
+    T* packed_state = blas.GEMM_ALLOC(CblasBMatrix, 1 /*height of C*/,
+                                      frame_size /*width of weight*/,
+                                      frame_size /*height of height*/);
+    PADDLE_ENFORCE(packed_state);
+    blas.GEMM_PACK(CblasBMatrix, CblasNoTrans, 1 /*cur bs?*/, frame_size,
+                   frame_size, T(1.0), gru_value.state_weight, frame_size,
+                   packed_state);
+#endif
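+    // Each batch n holds the n-th time step of every sequence that is still
+    // active after the length-based reordering done by to_batch.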
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
+      int cur_batch_size = bend - bstart;
+
+      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
+      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+      gru_value.output_value = hidden_t.data<T>();
+      gru_value.gate_value = gate_t.data<T>();
+      gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
+
+#ifdef PADDLE_WITH_MKLML
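+      // Accumulate h_{t-1} * W_{update,reset} into the first 2 * frame_size
+      // columns of the gate pre-activations using the packed gate weights.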
+      if (gru_value.prev_out_value) {
+        blas.GEMM_COMPUTE(CblasNoTrans, CblasPacked, cur_batch_size,
+                          frame_size * 2, frame_size, gru_value.prev_out_value,
+                          frame_size, packed_gate, frame_size * 2, T(1),
+                          gru_value.gate_value, frame_size * 3);
+      }
+
+      math::detail::forward_reset_output(
+          math::detail::forward::gru_resetOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_gate);
+
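+      // Accumulate the reset-gated previous hidden state times the candidate
+      // (state) weights into the last frame_size columns of gate_value.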
+      if (gru_value.prev_out_value) {
+        blas.GEMM_COMPUTE(
+            CblasNoTrans, CblasPacked, cur_batch_size, frame_size, frame_size,
+            gru_value.reset_output_value, frame_size, packed_state, frame_size,
+            T(1), gru_value.gate_value + frame_size * 2, frame_size * 3);
+      }
+
+      math::detail::forward_final_output(
+          math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
+          cur_batch_size, active_node);
+#else
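+      // Without MKLML, fall back to the generic GRU unit functor for this
+      // step.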
+      math::GRUUnitFunctor<DeviceContext, T>::compute(
+          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+          active_gate);
+#endif
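+      // This step's output becomes the previous hidden state for the next
+      // batch of time steps.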
+      gru_value.prev_out_value = gru_value.output_value;
+    }
+#ifdef PADDLE_WITH_MKLML
+    blas.GEMM_FREE(packed_gate);
+    blas.GEMM_FREE(packed_state);
+#endif
+
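+    // Convert the batch-ordered hidden states back to the original sequence
+    // (LoD) layout.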
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+    batch_hidden->set_lod(batch_gate->lod());
+    to_seq(dev_ctx, *batch_hidden, hidden);
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    BatchCompute(context);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(gru, ops::GRUOp, ops::GRUOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(gru_grad, ops::GRUGradOp);
-REGISTER_OP_CPU_KERNEL(
-    gru, ops::GRUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::GRUKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(gru, ops::GRUCPUKernel<float>,
+                       ops::GRUCPUKernel<double>);
 REGISTER_OP_CPU_KERNEL(
     gru_grad, ops::GRUGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::GRUGradKernel<paddle::platform::CPUDeviceContext, double>);