Skip to content

Commit 815d888

Browse files
committed
Clean MatMul
1 parent 9d7279b commit 815d888

File tree

12 files changed

+156
-299
lines changed

12 files changed

+156
-299
lines changed

paddle/fluid/operators/conv_op.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
161161
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
162162

163163
auto& dev_ctx = context.template device_context<DeviceContext>();
164+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
164165
for (int i = 0; i < batch_size; i++) {
165166
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
166167
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
186187
// gemm
187188
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
188189
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
189-
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
190-
false, T(1.0), &out_slice, T(0.0));
190+
blas.MatMul(filter_slice, col_matrix, &out_slice);
191191
}
192192
}
193193
}
@@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
274274

275275
math::SetConstant<DeviceContext, T> set_zero;
276276
auto& dev_ctx = context.template device_context<DeviceContext>();
277+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
277278

278279
if (input_grad) {
279280
input_grad->mutable_data<T>(context.GetPlace());
@@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
303304
col_matrix.ShareDataWith(in_grad_slice);
304305
col_matrix.Resize(col_matrix_shape);
305306
}
306-
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
307-
out_grad_slice, false, T(1.0),
308-
&col_matrix, T(0.0));
307+
blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
309308

310309
if (is_expand && data_dim == 2U) {
311310
col2im(dev_ctx, col, dilations, strides,
@@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
352351
// gemm
353352
Tensor filter_grad_slice =
354353
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
355-
math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
356-
col_matrix, true, T(1.0),
357-
&filter_grad_slice, T(1.0));
354+
blas.MatMul(out_grad_slice, false, col_matrix, true,
355+
&filter_grad_slice);
358356
}
359357
}
360358
}

paddle/fluid/operators/conv_transpose_op.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
118118
output->mutable_data<T>(context.GetPlace());
119119
math::SetConstant<DeviceContext, T> set_zero;
120120
auto& dev_ctx = context.template device_context<DeviceContext>();
121+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
121122
set_zero(dev_ctx, output, static_cast<T>(0));
122123

123124
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
@@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
134135

135136
// col_matrix = filter * input_batch
136137
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
137-
math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
138-
static_cast<T>(1.0), &col_matrix,
139-
static_cast<T>(0.0));
138+
blas.MatMul(filter, true, input_batch, false, &col_matrix);
140139

141140
if (data_dim == 2U) {
142141
// col2im: col_matrix -> dy
@@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
213212
// im2col + gemm (similar to conv-forward)
214213
// input need to compute gradient
215214
auto& dev_ctx = context.template device_context<DeviceContext>();
215+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
216216
if (input_grad || filter_grad) {
217217
Tensor col;
218218
col.mutable_data<T>(col_shape, context.GetPlace());
@@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
267267
// or
268268
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
269269
// d, h, w)
270-
math::matmul<DeviceContext, T>(
271-
dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
272-
&input_grad_batch, static_cast<T>(0.0));
270+
blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
273271
}
274272
if (filter_grad) {
275273
// input batch
@@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
279277
// or
280278
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
281279
// k_h * k_w)
282-
math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
283-
true, static_cast<T>(1.0),
284-
&filter_grad_, static_cast<T>(1.0));
280+
blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
285281
}
286282
}
287283
}

paddle/fluid/operators/lstm_op.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel<T> {
114114
auto cand_act = math::detail::GetActivationType(
115115
ctx.Attr<std::string>("candidate_activation"));
116116

117+
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
117118
for (size_t n = 0; n < num_batch; n++) {
118119
int bstart = static_cast<int>(batch_starts[n]);
119120
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel<T> {
129130
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
130131
int pre_h_end = pre_h_start + cur_batch_size;
131132
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
132-
math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
133-
false, static_cast<T>(1.0), &gate_t,
134-
static_cast<T>(1.0));
133+
blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
134+
&gate_t, static_cast<T>(1.0));
135135
} else if (hidden_t0) {
136136
// If n == 0 and there is no initialized hidden state, that is to say
137137
// the H0 is zeros, the calculation W_h * H0 will be skipped.
@@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel<T> {
143143
Tensor ordered_h0;
144144
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
145145
&ordered_h0, true);
146-
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
147-
false, static_cast<T>(1.0), &gate_t,
148-
static_cast<T>(1.0));
146+
blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
147+
&gate_t, static_cast<T>(1.0));
149148
}
150149

151150
lstm_value.gate_value = gate_t.data<T>();
@@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
282281

283282
auto batch_starts = batch_gate->lod()[0];
284283
size_t num_batch = batch_starts.size() - 1;
284+
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
285285
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
286286
int bstart = static_cast<int>(batch_starts[n]);
287287
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
320320
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
321321
int pre_h_end = pre_h_start + cur_batch_size;
322322
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
323-
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
324-
static_cast<T>(1.0), &pre_hidden_g,
325-
static_cast<T>(1.0));
323+
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
324+
&pre_hidden_g, static_cast<T>(1.0));
326325
if (weight_g) {
327326
/* backward weight */
328327
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
329-
math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
330-
false, static_cast<T>(1.0), weight_g,
331-
static_cast<T>(1.0));
328+
blas.MatMul(pre_hidden, true, gate_g, false, static_cast<T>(1.0),
329+
weight_g, static_cast<T>(1.0));
332330
}
333331
} else {
334332
if (h0 && weight_g) {
335333
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
336334
&ordered_h0, true);
337-
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
338-
false, static_cast<T>(1.0), weight_g,
339-
static_cast<T>(1.0));
335+
blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
336+
weight_g, static_cast<T>(1.0));
340337
}
341338
if (h0 && h0_g) {
342339
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
343-
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
344-
true, static_cast<T>(1.0),
345-
&ordered_h0_g, static_cast<T>(0.0));
340+
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
341+
&ordered_h0_g, static_cast<T>(0.0));
346342
}
347343
}
348344
}

paddle/fluid/operators/lstmp_op.h

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
143143
auto proj_act = math::detail::GetActivationType(
144144
ctx.Attr<std::string>("proj_activation"));
145145
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
146-
146+
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
147147
for (size_t n = 0; n < num_batch; n++) {
148148
int bstart = static_cast<int>(batch_starts[n]);
149149
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -160,9 +160,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
160160
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
161161
int pre_h_end = pre_h_start + cur_batch_size;
162162
auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end);
163-
math::matmul<DeviceContext, T>(device_ctx, pre_proj_t, false, *weight,
164-
false, static_cast<T>(1.0), &gate_t,
165-
static_cast<T>(1.0));
163+
blas.MatMul(pre_proj_t, false, *weight, false, static_cast<T>(1.0),
164+
&gate_t, static_cast<T>(1.0));
166165
} else if (hidden_t0) {
167166
// If n == 0 and there is no initialized hidden state, that is to say
168167
// the H0 is zeros, the calculation W_h * H0 will be skiped.
@@ -176,16 +175,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
176175
ordered_proj0->mutable_data<T>(ctx.GetPlace());
177176
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
178177
&ordered_h0, true);
179-
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
180-
*proj_weight, false, static_cast<T>(1.0),
181-
ordered_proj0, static_cast<T>(0.0));
178+
blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0),
179+
ordered_proj0, static_cast<T>(0.0));
182180
if (proj_act != math::detail::ActivationType::kIdentity) {
183181
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
184182
ActCompute(cell_act, place, proj0_dev, proj0_dev);
185183
}
186-
math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false,
187-
*weight, false, static_cast<T>(1.0),
188-
&gate_t, static_cast<T>(1.0));
184+
blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
185+
&gate_t, static_cast<T>(1.0));
189186
}
190187

191188
lstmp_value.gate_value = gate_t.data<T>();
@@ -196,9 +193,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
196193
device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
197194
cell_act, cand_act);
198195
lstmp_value.prev_state_value = lstmp_value.state_value;
199-
math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
200-
false, static_cast<T>(1.0), &proj_t,
201-
static_cast<T>(0.0));
196+
blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
197+
&proj_t, static_cast<T>(0.0));
202198
if (proj_act != math::detail::ActivationType::kIdentity) {
203199
auto proj_t_dev = EigenMatrix<T>::From(proj_t);
204200
ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
@@ -361,6 +357,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
361357

362358
auto batch_starts = batch_gate->lod()[0];
363359
size_t num_batch = batch_starts.size() - 1;
360+
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
364361
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
365362
int bstart = static_cast<int>(batch_starts[n]);
366363
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -375,15 +372,13 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
375372
}
376373
/* hidden state backward */
377374
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
378-
math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
379-
true, static_cast<T>(1.0), &out_g,
380-
static_cast<T>(0.0));
375+
blas.MatMul(proj_g, false, *proj_weight, true, static_cast<T>(1.0),
376+
&out_g, static_cast<T>(0.0));
381377
/* projection weight backward*/
382378
if (proj_weight_g) {
383379
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
384-
math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
385-
false, static_cast<T>(1.0),
386-
proj_weight_g, static_cast<T>(1.0));
380+
blas.MatMul(hidden_t, true, proj_g, false, static_cast<T>(1.0),
381+
proj_weight_g, static_cast<T>(1.0));
387382
}
388383

389384
Tensor gate = batch_gate->Slice(bstart, bend);
@@ -419,49 +414,43 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
419414
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
420415
int pre_h_end = pre_h_start + cur_batch_size;
421416
auto pre_proj_g = batch_proj_g.Slice(pre_h_start, pre_h_end);
422-
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
423-
static_cast<T>(1.0), &pre_proj_g,
424-
static_cast<T>(1.0));
417+
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
418+
&pre_proj_g, static_cast<T>(1.0));
425419
if (weight_g) {
426420
/* weight backward*/
427421
auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
428-
math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
429-
false, static_cast<T>(1.0), weight_g,
430-
static_cast<T>(1.0));
422+
blas.MatMul(pre_proj, true, gate_g, false, static_cast<T>(1.0),
423+
weight_g, static_cast<T>(1.0));
431424
}
432425
} else {
433426
if (h0 && weight_g) {
434427
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
435428
&ordered_h0, true);
436429
if (weight_g) {
437-
math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, true,
438-
gate_g, false, static_cast<T>(1.0),
439-
weight_g, static_cast<T>(1.0));
430+
blas.MatMul(*ordered_proj0, true, gate_g, false,
431+
static_cast<T>(1.0), weight_g, static_cast<T>(1.0));
440432
}
441433
}
442434
if (h0 && (h0_g || proj_weight_g)) {
443435
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
444436
Tensor proj0_g;
445437
proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
446438
proj0_g.mutable_data<T>(ctx.GetPlace());
447-
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
448-
true, static_cast<T>(1.0), &proj0_g,
449-
static_cast<T>(0.0));
439+
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
440+
&proj0_g, static_cast<T>(0.0));
450441
if (proj_act != math::detail::ActivationType::kIdentity) {
451442
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
452443
auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
453444
ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
454445
proj0_g_dev);
455446
}
456447
if (h0_g) {
457-
math::matmul<DeviceContext, T>(
458-
device_ctx, proj0_g, false, *proj_weight, true,
459-
static_cast<T>(1.0), &ordered_h0_g, static_cast<T>(0.0));
448+
blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
449+
&ordered_h0_g, static_cast<T>(0.0));
460450
}
461451
if (proj_weight_g) {
462-
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true,
463-
proj0_g, false, static_cast<T>(1.0),
464-
proj_weight_g, static_cast<T>(1.0));
452+
blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
453+
proj_weight_g, static_cast<T>(1.0));
465454
}
466455
}
467456
}

0 commit comments

Comments
 (0)