@@ -183,24 +183,27 @@ class FusionGRUKernel : public framework::OpKernel<T> {
183
183
const int total_T = x_dims[0 ]; \
184
184
const int D3 = wh_dims[1 ]
185
185
186
- #define INIT_OTHER_DEFINES \
187
- auto * h0 = ctx.Input<Tensor>(" H0" ); \
188
- auto * wx = ctx.Input<Tensor>(" WeightX" ); \
189
- auto * bias = ctx.Input<Tensor>(" Bias" ); \
190
- auto * hidden_out = ctx.Output<LoDTensor>(" Hidden" ); \
191
- bool is_reverse = ctx.Attr<bool >(" is_reverse" ); \
192
- const int M = x_dims[1 ]; \
193
- const int D = wh_dims[0 ]; \
194
- const int D2 = D * 2 ; \
195
- const auto & ker = math::jitkernel::KernelPool::Instance() \
196
- .template Get<math::jitkernel::GRUKernel<T>, \
197
- const std::string&, const std::string&>( \
198
- ctx.Attr<std::string>(" gate_activation" ), \
199
- ctx.Attr<std::string>(" activation" ), D); \
200
- const T* x_data = x->data<T>(); \
201
- const T* wx_data = wx->data<T>(); \
202
- const T* wh_data = wh->data<T>(); \
203
- auto place = ctx.GetPlace(); \
186
// INIT_OTHER_DEFINES: declares the remaining locals shared by the compute
// routines of this kernel. It expands to a run of declarations, so it must be
// used at statement scope where `ctx`, `x`, `wh`, `xx`, `x_dims` and `wh_dims`
// are already visible (presumably introduced by the preceding macro -- confirm).
//
// Locals introduced by the expansion:
//   h0, wx, bias     - input tensors fetched from ctx
//   hidden_out       - output tensor fetched from ctx
//   is_reverse       - bool attribute read from ctx
//   M, D, D2         - x_dims[1], wh_dims[0] and D * 2
//   attr             - gru_attr_t built from D and the two activation-name
//                      string attributes; also used as the KernelPool lookup key
//   one_step         - gru_t scratch struct; callers fill gates / ht_1 / ht and
//                      pass its address (together with &attr) to the ker->Compute* calls
//   ker              - GRUKernel<T> obtained from KernelPool for `attr`
//   x_data, wx_data, wh_data - read-only data pointers of x, wx, wh
//   place            - ctx.GetPlace()
//   xx_data          - mutable data pointer of xx on `place`
//
// The final declaration deliberately has no trailing ';' so the invocation
// site supplies it (INIT_OTHER_DEFINES;).
+ #define INIT_OTHER_DEFINES \
187
+ auto * h0 = ctx.Input<Tensor>(" H0" ); \
188
+ auto * wx = ctx.Input<Tensor>(" WeightX" ); \
189
+ auto * bias = ctx.Input<Tensor>(" Bias" ); \
190
+ auto * hidden_out = ctx.Output<LoDTensor>(" Hidden" ); \
191
+ bool is_reverse = ctx.Attr<bool >(" is_reverse" ); \
192
+ const int M = x_dims[1 ]; \
193
+ const int D = wh_dims[0 ]; \
194
+ const int D2 = D * 2 ; \
195
+ const math::jitkernel::gru_attr_t attr ( \
196
+ D, ctx.Attr<std::string>(" gate_activation" ), \
197
+ ctx.Attr<std::string>(" activation" )); \
198
+ math::jitkernel::gru_t one_step; \
199
+ const auto & ker = \
200
+ math::jitkernel::KernelPool::Instance () \
201
+ .template Get<math::jitkernel::GRUKernel<T>, \
202
+ const math::jitkernel::gru_attr_t &>(attr); \
203
+ const T* x_data = x->data<T>(); \
204
+ const T* wx_data = wx->data<T>(); \
205
+ const T* wh_data = wh->data<T>(); \
206
+ auto place = ctx.GetPlace(); \
204
207
T* xx_data = xx->mutable_data<T>(place)
205
208
206
209
void SeqCompute (const framework::ExecutionContext& ctx) const {
@@ -237,7 +240,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
237
240
if (h0_data) {
238
241
prev_hidden_data = h0_data + bid * D;
239
242
} else {
240
- ker->ComputeH1 (xx_data, hidden_out_data);
243
+ one_step.gates = xx_data;
244
+ one_step.ht = hidden_out_data;
245
+ ker->ComputeH1 (&one_step, &attr);
241
246
prev_hidden_data = hidden_out_data;
242
247
tstart = 1 ;
243
248
move_step ();
@@ -247,12 +252,15 @@ class FusionGRUKernel : public framework::OpKernel<T> {
247
252
blas.GEMM (CblasNoTrans, CblasNoTrans, 1 , D2, D, static_cast <T>(1 ),
248
253
prev_hidden_data, D, wh_data, D2, static_cast <T>(1 ), xx_data,
249
254
D3);
250
- ker->ComputeHtPart1 (xx_data, prev_hidden_data, hidden_out_data);
255
+ one_step.gates = xx_data;
256
+ one_step.ht_1 = prev_hidden_data;
257
+ one_step.ht = hidden_out_data;
258
+ ker->ComputeHtPart1 (&one_step, &attr);
251
259
// gemm rt * Ws
252
260
blas.GEMM (CblasNoTrans, CblasNoTrans, 1 , D, D, static_cast <T>(1 ),
253
261
hidden_out_data, D, wh_state_data, D, static_cast <T>(1 ),
254
262
xx_data + D2, D3);
255
- ker->ComputeHtPart2 (xx_data, prev_hidden_data, hidden_out_data );
263
+ ker->ComputeHtPart2 (&one_step, &attr );
256
264
// save prev
257
265
prev_hidden_data = hidden_out_data;
258
266
move_step ();
@@ -314,7 +322,9 @@ class FusionGRUKernel : public framework::OpKernel<T> {
314
322
T* cur_out_data = batched_out_data;
315
323
// W: {W_update, W_reset; W_state}
316
324
for (int i = 0 ; i < max_bs; ++i) {
317
- ker->ComputeH1 (cur_in_data, cur_out_data);
325
+ one_step.gates = cur_in_data;
326
+ one_step.ht = cur_out_data;
327
+ ker->ComputeH1 (&one_step, &attr);
318
328
// add offset
319
329
cur_in_data += D3;
320
330
cur_out_data += D;
@@ -339,8 +349,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
339
349
T* cur_out_data = batched_out_data;
340
350
T* cur_prev_hidden_data = prev_hidden_data;
341
351
for (int i = 0 ; i < cur_bs; ++i) {
342
- ker->ComputeHtPart1 (cur_batched_data, cur_prev_hidden_data,
343
- cur_out_data);
352
+ one_step.gates = cur_batched_data;
353
+ one_step.ht_1 = cur_prev_hidden_data;
354
+ one_step.ht = cur_out_data;
355
+ ker->ComputeHtPart1 (&one_step, &attr);
356
+
344
357
cur_batched_data += D3;
345
358
cur_prev_hidden_data += D;
346
359
cur_out_data += D;
@@ -354,8 +367,10 @@ class FusionGRUKernel : public framework::OpKernel<T> {
354
367
355
368
cur_prev_hidden_data = prev_hidden_data;
356
369
for (int i = 0 ; i < cur_bs; ++i) {
357
- ker->ComputeHtPart2 (cur_batched_data, cur_prev_hidden_data,
358
- cur_out_data);
370
+ one_step.gates = cur_batched_data;
371
+ one_step.ht_1 = cur_prev_hidden_data;
372
+ one_step.ht = cur_out_data;
373
+ ker->ComputeHtPart2 (&one_step, &attr);
359
374
cur_batched_data += D3;
360
375
cur_prev_hidden_data += D;
361
376
cur_out_data += D;
0 commit comments