@@ -143,7 +143,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
143
143
auto proj_act = math::detail::GetActivationType (
144
144
ctx.Attr <std::string>(" proj_activation" ));
145
145
auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
146
-
146
+ auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
147
147
for (size_t n = 0 ; n < num_batch; n++) {
148
148
int bstart = static_cast <int >(batch_starts[n]);
149
149
int bend = static_cast <int >(batch_starts[n + 1 ]);
@@ -160,9 +160,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
160
160
int pre_h_start = static_cast <int >(batch_starts[n - 1 ]);
161
161
int pre_h_end = pre_h_start + cur_batch_size;
162
162
auto pre_proj_t = batch_proj.Slice (pre_h_start, pre_h_end);
163
- math::matmul<DeviceContext, T>(device_ctx, pre_proj_t , false , *weight,
164
- false , static_cast <T>(1.0 ), &gate_t ,
165
- static_cast <T>(1.0 ));
163
+ blas.MatMul (pre_proj_t , false , *weight, false , static_cast <T>(1.0 ),
164
+ &gate_t , static_cast <T>(1.0 ));
166
165
} else if (hidden_t0) {
167
166
// If n == 0 and there is no initialized hidden state, that is to say
168
167
// the H0 is zeros, the calculation W_h * H0 will be skipped.
@@ -176,16 +175,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
176
175
ordered_proj0->mutable_data <T>(ctx.GetPlace ());
177
176
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
178
177
&ordered_h0, true );
179
- math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false ,
180
- *proj_weight, false , static_cast <T>(1.0 ),
181
- ordered_proj0, static_cast <T>(0.0 ));
178
+ blas.MatMul (ordered_h0, false , *proj_weight, false , static_cast <T>(1.0 ),
179
+ ordered_proj0, static_cast <T>(0.0 ));
182
180
if (proj_act != math::detail::ActivationType::kIdentity ) {
183
181
auto proj0_dev = EigenMatrix<T>::From (*ordered_proj0);
184
182
ActCompute (cell_act, place, proj0_dev, proj0_dev);
185
183
}
186
- math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false ,
187
- *weight, false , static_cast <T>(1.0 ),
188
- &gate_t , static_cast <T>(1.0 ));
184
+ blas.MatMul (*ordered_proj0, false , *weight, false , static_cast <T>(1.0 ),
185
+ &gate_t , static_cast <T>(1.0 ));
189
186
}
190
187
191
188
lstmp_value.gate_value = gate_t .data <T>();
@@ -196,9 +193,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
196
193
device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
197
194
cell_act, cand_act);
198
195
lstmp_value.prev_state_value = lstmp_value.state_value ;
199
- math::matmul<DeviceContext, T>(device_ctx, hidden_t , false , *proj_weight,
200
- false , static_cast <T>(1.0 ), &proj_t ,
201
- static_cast <T>(0.0 ));
196
+ blas.MatMul (hidden_t , false , *proj_weight, false , static_cast <T>(1.0 ),
197
+ &proj_t , static_cast <T>(0.0 ));
202
198
if (proj_act != math::detail::ActivationType::kIdentity ) {
203
199
auto proj_t_dev = EigenMatrix<T>::From (proj_t );
204
200
ActCompute (cell_act, place, proj_t_dev, proj_t_dev);
@@ -361,6 +357,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
361
357
362
358
auto batch_starts = batch_gate->lod ()[0 ];
363
359
size_t num_batch = batch_starts.size () - 1 ;
360
+ auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
364
361
for (int n = static_cast <int >(num_batch) - 1 ; n >= 0 ; n--) {
365
362
int bstart = static_cast <int >(batch_starts[n]);
366
363
int bend = static_cast <int >(batch_starts[n + 1 ]);
@@ -375,15 +372,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
375
372
}
376
373
/* hidden state backward */
377
374
Tensor out_g = batch_hidden_g.Slice (bstart, bend);
378
- math::matmul<DeviceContext, T>(device_ctx, proj_g, false , *proj_weight,
379
- true , static_cast <T>(1.0 ), &out_g,
380
- static_cast <T>(0.0 ));
375
+ blas.MatMul (proj_g, false , *proj_weight, true , static_cast <T>(1.0 ),
376
+ &out_g, static_cast <T>(0.0 ));
381
377
/* projection weight backward*/
382
378
if (proj_weight_g) {
383
379
Tensor hidden_t = batch_hidden->Slice (bstart, bend);
384
- math::matmul<DeviceContext, T>(device_ctx, hidden_t , true , proj_g,
385
- false , static_cast <T>(1.0 ),
386
- proj_weight_g, static_cast <T>(1.0 ));
380
+ blas.MatMul (hidden_t , true , proj_g, false , static_cast <T>(1.0 ),
381
+ proj_weight_g, static_cast <T>(1.0 ));
387
382
}
388
383
389
384
Tensor gate = batch_gate->Slice (bstart, bend);
@@ -419,49 +414,43 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
419
414
int pre_h_start = static_cast <int >(batch_starts[n - 1 ]);
420
415
int pre_h_end = pre_h_start + cur_batch_size;
421
416
auto pre_proj_g = batch_proj_g.Slice (pre_h_start, pre_h_end);
422
- math::matmul<DeviceContext, T>(device_ctx, gate_g, false , *weight, true ,
423
- static_cast <T>(1.0 ), &pre_proj_g,
424
- static_cast <T>(1.0 ));
417
+ blas.MatMul (gate_g, false , *weight, true , static_cast <T>(1.0 ),
418
+ &pre_proj_g, static_cast <T>(1.0 ));
425
419
if (weight_g) {
426
420
/* weight backward*/
427
421
auto pre_proj = batch_proj.Slice (pre_h_start, pre_h_end);
428
- math::matmul<DeviceContext, T>(device_ctx, pre_proj, true , gate_g,
429
- false , static_cast <T>(1.0 ), weight_g,
430
- static_cast <T>(1.0 ));
422
+ blas.MatMul (pre_proj, true , gate_g, false , static_cast <T>(1.0 ),
423
+ weight_g, static_cast <T>(1.0 ));
431
424
}
432
425
} else {
433
426
if (h0 && weight_g) {
434
427
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
435
428
&ordered_h0, true );
436
429
if (weight_g) {
437
- math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, true ,
438
- gate_g, false , static_cast <T>(1.0 ),
439
- weight_g, static_cast <T>(1.0 ));
430
+ blas.MatMul (*ordered_proj0, true , gate_g, false ,
431
+ static_cast <T>(1.0 ), weight_g, static_cast <T>(1.0 ));
440
432
}
441
433
}
442
434
if (h0 && (h0_g || proj_weight_g)) {
443
435
ordered_h0_g.mutable_data <T>(h0_g->dims (), ctx.GetPlace ());
444
436
Tensor proj0_g;
445
437
proj0_g.Resize ({in_dims[0 ], proj_weight->dims ()[1 ]});
446
438
proj0_g.mutable_data <T>(ctx.GetPlace ());
447
- math::matmul<DeviceContext, T>(device_ctx, gate_g, false , *weight,
448
- true , static_cast <T>(1.0 ), &proj0_g,
449
- static_cast <T>(0.0 ));
439
+ blas.MatMul (gate_g, false , *weight, true , static_cast <T>(1.0 ),
440
+ &proj0_g, static_cast <T>(0.0 ));
450
441
if (proj_act != math::detail::ActivationType::kIdentity ) {
451
442
auto proj0_dev = EigenMatrix<T>::From (*ordered_proj0);
452
443
auto proj0_g_dev = EigenMatrix<T>::From (proj0_g);
453
444
ActGradCompute (cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
454
445
proj0_g_dev);
455
446
}
456
447
if (h0_g) {
457
- math::matmul<DeviceContext, T>(
458
- device_ctx, proj0_g, false , *proj_weight, true ,
459
- static_cast <T>(1.0 ), &ordered_h0_g, static_cast <T>(0.0 ));
448
+ blas.MatMul (proj0_g, false , *proj_weight, true , static_cast <T>(1.0 ),
449
+ &ordered_h0_g, static_cast <T>(0.0 ));
460
450
}
461
451
if (proj_weight_g) {
462
- math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true ,
463
- proj0_g, false , static_cast <T>(1.0 ),
464
- proj_weight_g, static_cast <T>(1.0 ));
452
+ blas.MatMul (ordered_h0, true , proj0_g, false , static_cast <T>(1.0 ),
453
+ proj_weight_g, static_cast <T>(1.0 ));
465
454
}
466
455
}
467
456
}
0 commit comments