Skip to content

Commit 0285a2b

Browse files
authored
Merge pull request #10371 from reyoung/refine_code
Polish MatMul, clean copy & paste code
2 parents 67f42cc + ef6ea79 commit 0285a2b

23 files changed

+526
-671
lines changed

paddle/fluid/operators/bilinear_tensor_product_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ limitations under the License. */
1616

1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/op_registry.h"
19-
#include "paddle/fluid/operators/math/math_function.h"
19+
#include "paddle/fluid/operators/math/blas.h"
2020

2121
namespace paddle {
2222
namespace operators {

paddle/fluid/operators/conv_op.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ limitations under the License. */
1717
#include <vector>
1818
#include "paddle/fluid/framework/eigen.h"
1919
#include "paddle/fluid/framework/op_registry.h"
20+
#include "paddle/fluid/operators/math/blas.h"
2021
#include "paddle/fluid/operators/math/depthwise_conv.h"
2122
#include "paddle/fluid/operators/math/im2col.h"
22-
#include "paddle/fluid/operators/math/math_function.h"
2323
#include "paddle/fluid/operators/math/vol2col.h"
2424

2525
namespace paddle {
@@ -161,6 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
161161
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
162162

163163
auto& dev_ctx = context.template device_context<DeviceContext>();
164+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
164165
for (int i = 0; i < batch_size; i++) {
165166
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
166167
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -186,8 +187,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
186187
// gemm
187188
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
188189
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
189-
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, false, col_matrix,
190-
false, T(1.0), &out_slice, T(0.0));
190+
blas.MatMul(filter_slice, col_matrix, &out_slice);
191191
}
192192
}
193193
}
@@ -274,6 +274,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
274274

275275
math::SetConstant<DeviceContext, T> set_zero;
276276
auto& dev_ctx = context.template device_context<DeviceContext>();
277+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
277278

278279
if (input_grad) {
279280
input_grad->mutable_data<T>(context.GetPlace());
@@ -303,9 +304,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
303304
col_matrix.ShareDataWith(in_grad_slice);
304305
col_matrix.Resize(col_matrix_shape);
305306
}
306-
math::matmul<DeviceContext, T>(dev_ctx, filter_slice, true,
307-
out_grad_slice, false, T(1.0),
308-
&col_matrix, T(0.0));
307+
blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
309308

310309
if (is_expand && data_dim == 2U) {
311310
col2im(dev_ctx, col, dilations, strides,
@@ -352,9 +351,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
352351
// gemm
353352
Tensor filter_grad_slice =
354353
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
355-
math::matmul<DeviceContext, T>(dev_ctx, out_grad_slice, false,
356-
col_matrix, true, T(1.0),
357-
&filter_grad_slice, T(1.0));
354+
blas.MatMul(out_grad_slice, false, col_matrix, true,
355+
&filter_grad_slice);
358356
}
359357
}
360358
}

paddle/fluid/operators/conv_transpose_op.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License. */
1616
#include <vector>
1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/math/blas.h"
1920
#include "paddle/fluid/operators/math/im2col.h"
20-
#include "paddle/fluid/operators/math/math_function.h"
2121
#include "paddle/fluid/operators/math/vol2col.h"
2222

2323
namespace paddle {
@@ -118,6 +118,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
118118
output->mutable_data<T>(context.GetPlace());
119119
math::SetConstant<DeviceContext, T> set_zero;
120120
auto& dev_ctx = context.template device_context<DeviceContext>();
121+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
121122
set_zero(dev_ctx, output, static_cast<T>(0));
122123

123124
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
@@ -134,9 +135,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
134135

135136
// col_matrix = filter * input_batch
136137
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
137-
math::matmul<DeviceContext, T>(dev_ctx, filter, true, input_batch, false,
138-
static_cast<T>(1.0), &col_matrix,
139-
static_cast<T>(0.0));
138+
blas.MatMul(filter, true, input_batch, false, &col_matrix);
140139

141140
if (data_dim == 2U) {
142141
// col2im: col_matrix -> dy
@@ -213,6 +212,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
213212
// im2col + gemm (similar to conv-forward)
214213
// input need to compute gradient
215214
auto& dev_ctx = context.template device_context<DeviceContext>();
215+
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
216216
if (input_grad || filter_grad) {
217217
Tensor col;
218218
col.mutable_data<T>(col_shape, context.GetPlace());
@@ -267,9 +267,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
267267
// or
268268
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
269269
// d, h, w)
270-
math::matmul<DeviceContext, T>(
271-
dev_ctx, filter, false, col_matrix, false, static_cast<T>(1.0),
272-
&input_grad_batch, static_cast<T>(0.0));
270+
blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
273271
}
274272
if (filter_grad) {
275273
// input batch
@@ -279,9 +277,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
279277
// or
280278
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
281279
// k_h * k_w)
282-
math::matmul<DeviceContext, T>(dev_ctx, in_batch, false, col_matrix,
283-
true, static_cast<T>(1.0),
284-
&filter_grad_, static_cast<T>(1.0));
280+
blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
285281
}
286282
}
287283
}

paddle/fluid/operators/gru_unit_op.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include "paddle/fluid/operators/activation_op.h"
18-
#include "paddle/fluid/operators/math/math_function.h"
19-
2017
#include "paddle/fluid/framework/eigen.h"
2118
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/activation_op.h"
20+
#include "paddle/fluid/operators/math/blas.h"
2221

2322
namespace paddle {
2423
namespace operators {

paddle/fluid/operators/layer_norm_op.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ limitations under the License. */
1515
#pragma once
1616
#include "paddle/fluid/framework/eigen.h"
1717
#include "paddle/fluid/framework/op_registry.h"
18-
1918
#include "paddle/fluid/operators/elementwise_op_function.h"
19+
#include "paddle/fluid/operators/math/blas.h"
2020
#include "paddle/fluid/operators/math/math_function.h"
2121

2222
namespace paddle {
@@ -46,9 +46,9 @@ class RowwiseMean2D<platform::CUDADeviceContext, T> {
4646
}
4747
void operator()(const platform::CUDADeviceContext& context,
4848
const framework::Tensor& input, framework::Tensor* out) {
49-
math::gemv<platform::CUDADeviceContext, T>(
50-
context, false, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
51-
0., out->data<T>());
49+
math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
50+
false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
51+
out->data<T>());
5252
}
5353

5454
private:
@@ -93,9 +93,9 @@ class ColwiseSum2D<platform::CUDADeviceContext, T> {
9393

9494
void operator()(const platform::CUDADeviceContext& context,
9595
const framework::Tensor& input, framework::Tensor* out) {
96-
math::gemv<platform::CUDADeviceContext, T>(
97-
context, true, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
98-
0., out->data<T>());
96+
math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
97+
true, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
98+
out->data<T>());
9999
}
100100

101101
private:

paddle/fluid/operators/lstm_op.h

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ limitations under the License. */
1515
#pragma once
1616
#include <string>
1717
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/operators/math/blas.h"
1819
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1920
#include "paddle/fluid/operators/math/lstm_compute.h"
20-
#include "paddle/fluid/operators/math/math_function.h"
2121
#include "paddle/fluid/operators/math/sequence2batch.h"
2222

2323
namespace paddle {
@@ -114,6 +114,7 @@ class LSTMKernel : public framework::OpKernel<T> {
114114
auto cand_act = math::detail::GetActivationType(
115115
ctx.Attr<std::string>("candidate_activation"));
116116

117+
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
117118
for (size_t n = 0; n < num_batch; n++) {
118119
int bstart = static_cast<int>(batch_starts[n]);
119120
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -129,9 +130,8 @@ class LSTMKernel : public framework::OpKernel<T> {
129130
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
130131
int pre_h_end = pre_h_start + cur_batch_size;
131132
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
132-
math::matmul<DeviceContext, T>(device_ctx, pre_hidden_t, false, *weight,
133-
false, static_cast<T>(1.0), &gate_t,
134-
static_cast<T>(1.0));
133+
blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
134+
&gate_t, static_cast<T>(1.0));
135135
} else if (hidden_t0) {
136136
// If n == 0 and there is no initialized hidden state, that is to say
137137
// the H0 is zeros, the calculation W_h * H0 will be skipped.
@@ -143,9 +143,8 @@ class LSTMKernel : public framework::OpKernel<T> {
143143
Tensor ordered_h0;
144144
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
145145
&ordered_h0, true);
146-
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false, *weight,
147-
false, static_cast<T>(1.0), &gate_t,
148-
static_cast<T>(1.0));
146+
blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
147+
&gate_t, static_cast<T>(1.0));
149148
}
150149

151150
lstm_value.gate_value = gate_t.data<T>();
@@ -282,6 +281,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
282281

283282
auto batch_starts = batch_gate->lod()[0];
284283
size_t num_batch = batch_starts.size() - 1;
284+
auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
285285
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
286286
int bstart = static_cast<int>(batch_starts[n]);
287287
int bend = static_cast<int>(batch_starts[n + 1]);
@@ -320,29 +320,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
320320
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
321321
int pre_h_end = pre_h_start + cur_batch_size;
322322
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
323-
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight, true,
324-
static_cast<T>(1.0), &pre_hidden_g,
325-
static_cast<T>(1.0));
323+
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
324+
&pre_hidden_g, static_cast<T>(1.0));
326325
if (weight_g) {
327326
/* backward weight */
328327
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
329-
math::matmul<DeviceContext, T>(device_ctx, pre_hidden, true, gate_g,
330-
false, static_cast<T>(1.0), weight_g,
331-
static_cast<T>(1.0));
328+
blas.MatMul(pre_hidden, true, gate_g, false, static_cast<T>(1.0),
329+
weight_g, static_cast<T>(1.0));
332330
}
333331
} else {
334332
if (h0 && weight_g) {
335333
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
336334
&ordered_h0, true);
337-
math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
338-
false, static_cast<T>(1.0), weight_g,
339-
static_cast<T>(1.0));
335+
blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
336+
weight_g, static_cast<T>(1.0));
340337
}
341338
if (h0 && h0_g) {
342339
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
343-
math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
344-
true, static_cast<T>(1.0),
345-
&ordered_h0_g, static_cast<T>(0.0));
340+
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
341+
&ordered_h0_g, static_cast<T>(0.0));
346342
}
347343
}
348344
}

0 commit comments

Comments
 (0)