Skip to content

Commit c6a6d87

Browse files
committed
Rewrite Matmul, make code cleaner
1 parent 0285a2b commit c6a6d87

File tree

6 files changed

+258
-418
lines changed

6 files changed

+258
-418
lines changed

paddle/fluid/operators/math/blas.cc

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,47 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/operators/math/blas.h"
16+
17+
#include <utility>
1618
namespace paddle {
1719
namespace operators {
1820
namespace math {
19-
// Do nothing. Blas is a header only library.
21+
// Derive the matmul layout (batch count, matrix height/width, batch stride)
// of a tensor shape.
//
// - num_flatten_cols > 1: flatten `dim` to 2-D at that axis; result is a
//   single (non-batched) matrix.
// - rank 1: treated as a 1 x dim[0] row vector.
// - rank 2: plain matrix.
// - rank 3: dim[0] is the batch, trailing two dims are the matrix.
// - rank > 3: all leading dims are folded into the batch, trailing two dims
//   are the matrix.
//
// `trans` swaps height/width and is recorded in `trans_` so the GEMM call can
// pass the transpose flag. `stride_` (elements between consecutive batch
// matrices) is computed before the swap; height * width is symmetric, so the
// order does not matter.
MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) {
  MatDim retv;
  if (num_flatten_cols > 1) {
    auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
    retv.height_ = flatten_dim[0];
    retv.width_ = flatten_dim[1];
  } else {
    if (dim.size() == 1) {
      retv.height_ = 1;
      retv.width_ = dim[0];
    } else if (dim.size() == 2) {
      retv.height_ = dim[0];
      retv.width_ = dim[1];
    } else {
      if (dim.size() == 3) {
        retv.batch_size_ = dim[0];
        retv.height_ = dim[1];
        retv.width_ = dim[2];
      } else {
        auto dim_vec = framework::vectorize(dim);
        retv.batch_size_ = 1;
        for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
          retv.batch_size_ *= dim_vec[i];
        }
        // Loop-invariant: the matrix shape is always the trailing two dims,
        // so set it once after accumulating the batch size (previously these
        // assignments were redundantly repeated inside the loop).
        retv.height_ = dim_vec[dim_vec.size() - 2];
        retv.width_ = dim_vec[dim_vec.size() - 1];
      }
      retv.stride_ = retv.height_ * retv.width_;
    }
  }
  if (trans) {
    std::swap(retv.width_, retv.height_);
  }
  retv.trans_ = trans;
  return retv;
}
2057
} // namespace math
2158
} // namespace operators
2259
} // namespace paddle

paddle/fluid/operators/math/blas.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ namespace paddle {
4646
namespace operators {
4747
namespace math {
4848

49+
// Shape metadata describing one operand of a (possibly batched) matrix
// multiply. A batch_size_ of 0 means "not batched"; stride_ is the element
// offset between consecutive matrices of a batch.
struct MatDim {
  int64_t height_{0};      // rows of the (logical, post-transpose) matrix
  int64_t width_{0};       // columns of the (logical, post-transpose) matrix
  int64_t stride_{0};      // per-batch element stride (0 when not batched)
  int64_t batch_size_{0};  // number of batched matrices (0 = single matrix)
  bool trans_{false};      // pass the operand transposed to GEMM
};
56+
57+
// Derives the matmul layout (batch/height/width/stride) of `tensor`'s shape;
// defined in blas.cc.
extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
                        bool trans);
59+
4960
template <typename DeviceContext>
5061
class Blas {
5162
public:
@@ -90,6 +101,28 @@ class Blas {
90101
int K, T alpha, const T* A, const T* B, T beta, T* C,
91102
int batchCount, int64_t strideA, int64_t strideB) const;
92103

104+
template <typename T>
105+
void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
106+
const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
107+
framework::Tensor* mat_out, T beta) const {
108+
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
109+
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
110+
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
111+
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
112+
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
113+
dim_a.width_, alpha, mat_a.data<T>(),
114+
mat_b.data<T>(), beta, mat_out->data<T>());
115+
} else {
116+
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
117+
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
118+
this->template BatchedGEMM<T>(
119+
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
120+
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
121+
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
122+
dim_a.stride_, dim_b.stride_);
123+
}
124+
}
125+
93126
private:
94127
const DeviceContext& context_;
95128
};

paddle/fluid/operators/math/matmul.h

Lines changed: 0 additions & 149 deletions
This file was deleted.

paddle/fluid/operators/matmul_op.cc

Lines changed: 26 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -36,121 +36,39 @@ class MatMulOp : public framework::OperatorWithKernel {
3636

3737
auto dim_x = context->GetInputDim("X");
3838
auto dim_y = context->GetInputDim("Y");
39-
bool transpose_x = context->Attrs().Get<bool>("transpose_X");
40-
bool transpose_y = context->Attrs().Get<bool>("transpose_Y");
41-
42-
PADDLE_ENFORCE_GE(dim_x.size(), 1,
43-
"Input tensor X must be at least 1-dimensional.");
44-
PADDLE_ENFORCE_GE(dim_y.size(), 1,
45-
"Input tensor Y must be at least 1-dimensional.");
46-
47-
std::vector<int64_t> out_dim;
48-
int64_t batch_count = 1;
49-
if (dim_x.size() > 3) {
50-
PADDLE_ENFORCE_EQ(
51-
dim_y.size(), dim_x.size(),
52-
"The dimensions of X and Y must be the same, and both of "
53-
"them should be %d-dimensional.",
54-
dim_x.size());
55-
56-
// The first rank-2 dimensions are accumulated on the batch_count, and the
57-
// last two dimensions are used for matrix multiplication.
58-
for (int j = 0; j < dim_x.size() - 2; ++j) {
59-
PADDLE_ENFORCE_EQ(dim_y[j], dim_x[j],
60-
"The %d-th dimension of X and Y must be the same.",
61-
j);
62-
out_dim.push_back(dim_x[j]);
63-
batch_count *= dim_x[j];
64-
}
65-
}
6639

67-
int M = 0, N = 0, KX = 0, KY = 0, batchCountX = 0, batchCountY = 0;
68-
bool remove_initial_dim = false, remove_final_dim = false;
69-
70-
switch (dim_x.size()) {
71-
case 1:
72-
if (transpose_x) {
73-
M = dim_x[0];
74-
KX = 1;
75-
} else {
76-
M = 1;
77-
KX = dim_x[0];
78-
remove_initial_dim = true;
79-
}
80-
break;
81-
case 2:
82-
M = transpose_x ? dim_x[1] : dim_x[0];
83-
KX = transpose_x ? dim_x[0] : dim_x[1];
84-
break;
85-
case 3:
86-
batchCountX = dim_x[0];
87-
M = transpose_x ? dim_x[2] : dim_x[1];
88-
KX = transpose_x ? dim_x[1] : dim_x[2];
89-
break;
90-
default:
91-
batchCountX = batch_count;
92-
size_t mat_s = dim_x.size() - 2;
93-
M = transpose_x ? dim_x[mat_s + 1] : dim_x[mat_s];
94-
KX = transpose_x ? dim_x[mat_s] : dim_x[mat_s + 1];
95-
break;
96-
}
40+
auto mat_dim_x = math::GetMatDim(GetXDim(dim_x), 0,
41+
context->Attrs().Get<bool>("transpose_X"));
42+
auto mat_dim_y = math::GetMatDim(GetYDim(dim_y), 0,
43+
context->Attrs().Get<bool>("transpose_Y"));
9744

98-
switch (dim_y.size()) {
99-
case 1:
100-
if (transpose_y) {
101-
N = dim_y[0];
102-
KY = 1;
103-
} else {
104-
N = 1;
105-
KY = dim_y[0];
106-
remove_final_dim = true;
107-
}
108-
break;
109-
case 2:
110-
KY = transpose_y ? dim_y[1] : dim_y[0];
111-
N = transpose_y ? dim_y[0] : dim_y[1];
112-
break;
113-
case 3:
114-
batchCountY = dim_y[0];
115-
KY = transpose_y ? dim_y[2] : dim_y[1];
116-
N = transpose_y ? dim_y[1] : dim_y[2];
117-
break;
118-
default:
119-
batchCountY = batch_count;
120-
size_t mat_s = dim_y.size() - 2;
121-
KY = transpose_y ? dim_y[mat_s + 1] : dim_y[mat_s];
122-
N = transpose_y ? dim_y[mat_s] : dim_y[mat_s + 1];
45+
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
46+
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
47+
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
48+
std::vector<int64_t> dim_out;
49+
if (mat_dim_x.batch_size_ != 0) {
50+
dim_out = framework::vectorize(dim_x);
51+
dim_out[dim_out.size() - 2] = mat_dim_x.height_;
52+
dim_out[dim_out.size() - 1] = mat_dim_y.width_;
53+
} else if (mat_dim_y.batch_size_ != 0) {
54+
dim_out = framework::vectorize(dim_y);
55+
dim_out[dim_out.size() - 2] = mat_dim_x.height_;
56+
dim_out[dim_out.size() - 1] = mat_dim_y.width_;
57+
} else {
58+
dim_out = {mat_dim_x.height_, mat_dim_y.width_};
12359
}
12460

125-
PADDLE_ENFORCE_EQ(
126-
KX, KY,
127-
"First matrix's width must be equal with second matrix's height.");
128-
if (batchCountX && batchCountY) {
129-
PADDLE_ENFORCE_EQ(
130-
batchCountX, batchCountY,
131-
"When Input(X) and Input(Y) are both three dimensional, they "
132-
"must have the same batch dimension.");
61+
if (dim_x.size() == 1 && dim_out[dim_out.size() - 2] == 1) {
62+
std::swap(dim_out[dim_out.size() - 2], dim_out[dim_out.size() - 1]);
63+
dim_out.resize(dim_out.size() - 1);
13364
}
134-
int batchCount = std::max(batchCountX, batchCountY);
13565

136-
std::vector<int64_t> dim_out;
137-
if (batchCount) {
138-
if (dim_x.size() > 3) {
139-
dim_out.insert(dim_out.begin(), out_dim.begin(), out_dim.end());
140-
} else {
141-
dim_out.push_back(batchCount);
142-
}
66+
if (dim_y.size() == 1 && dim_out[dim_out.size() - 1] == 1) {
67+
dim_out.resize(dim_out.size() - 1);
14368
}
144-
if (!remove_initial_dim) {
145-
dim_out.push_back(M);
146-
}
147-
if (!remove_final_dim) {
148-
dim_out.push_back(N);
149-
}
150-
if (dim_out.size() == 0) {
151-
// We don't support 0-dimensional Tensors (scalars), so instead
152-
// treat the output as a Tensor of shape (1, ) in this case.
153-
dim_out.push_back(1);
69+
70+
if (dim_out.empty()) {
71+
dim_out = {1};
15472
}
15573
context->SetOutputDim("Out", framework::make_ddim(dim_out));
15674
context->ShareLoD("X", /*->*/ "Out");

0 commit comments

Comments (0)