Commit 705e734

Merge pull request #10449 from reyoung/feature/clean_matmul
Rewrite Matmul, make code cleaner
2 parents: 3665358 + ad594b9

File tree

8 files changed, +423 / -543 lines

paddle/fluid/operators/math/blas.cc

Lines changed: 31 additions & 1 deletion
@@ -13,10 +13,40 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/math/blas.h"
+
+#include <utility>
 namespace paddle {
 namespace operators {
 namespace math {
-// Do nothing. Blas is a header only library.
+MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim,
+                                     int num_flatten_cols, bool trans) {
+  PADDLE_ENFORCE_GT(tensor_dim.size(), 1);
+  MatDescriptor retv;
+  if (num_flatten_cols > 1) {
+    auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols);
+    retv.height_ = flatten_dim[0];
+    retv.width_ = flatten_dim[1];
+  } else {
+    if (tensor_dim.size() == 2) {
+      retv.height_ = tensor_dim[0];
+      retv.width_ = tensor_dim[1];
+    } else {
+      auto dim_vec = framework::vectorize(tensor_dim);
+      retv.batch_size_ = 1;
+      for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
+        retv.batch_size_ *= dim_vec[i];
+      }
+      retv.height_ = dim_vec[dim_vec.size() - 2];
+      retv.width_ = dim_vec[dim_vec.size() - 1];
+      retv.stride_ = retv.height_ * retv.width_;
+    }
+  }
+  if (trans) {
+    std::swap(retv.width_, retv.height_);
+  }
+  retv.trans_ = trans;
+  return retv;
+}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
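As a rough illustration of the shape bookkeeping above, the following standalone sketch mirrors CreateMatrixDescriptor's branching with a plain std::vector<int64_t> in place of framework::DDim. The flattening path is omitted, and the names Desc/Describe are illustrative only, not part of the Paddle API.

// Standalone sketch of the descriptor construction in the diff above.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

struct Desc {
  int64_t height_{0};
  int64_t width_{0};
  int64_t stride_{0};      // height_ * width_ for batched operands, else 0
  int64_t batch_size_{0};  // product of the leading rank-2 dims, else 0
  bool trans_{false};
};

Desc Describe(const std::vector<int64_t>& dim, bool trans) {
  assert(dim.size() > 1);
  Desc d;
  if (dim.size() == 2) {  // plain matrix: [H, W]
    d.height_ = dim[0];
    d.width_ = dim[1];
  } else {                // batched matrix: [D0, ..., Dk, H, W]
    d.batch_size_ = 1;
    for (size_t i = 0; i + 2 < dim.size(); ++i) d.batch_size_ *= dim[i];
    d.height_ = dim[dim.size() - 2];
    d.width_ = dim[dim.size() - 1];
    d.stride_ = d.height_ * d.width_;
  }
  if (trans) std::swap(d.height_, d.width_);
  d.trans_ = trans;
  return d;
}

int main() {
  Desc a = Describe({2, 3, 4}, false);  // a [2, 3, 4] tensor
  assert(a.batch_size_ == 2 && a.height_ == 3 && a.width_ == 4 && a.stride_ == 12);
  Desc b = Describe({4, 5}, true);      // a transposed [4, 5] matrix
  assert(b.height_ == 5 && b.width_ == 4 && b.batch_size_ == 0);
  return 0;
}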

paddle/fluid/operators/math/blas.h

Lines changed: 49 additions & 0 deletions
@@ -46,6 +46,50 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+/**
+ * Matrix descriptor of a memory buffer.
+ *
+ * It is used by Blas::MatMul. The MatMul operator can be batched: if Mat A is
+ * [BatchSize, H1, W1] and Mat B is [BatchSize, W1, W2], the call becomes
+ * `batch_size` GEMMs. The batched GEMM can be faster, depending on the
+ * implementation of the BLAS library. The batch size may also be zero. If any
+ * matrix of `matmul` has a batch size, the call is a batched GEMM as well,
+ * e.g., if Mat A is [BatchSize, H1, W1] and Mat B is [W1, W2], the result
+ * matrix will be [BatchSize, H1, W2].
+ *
+ * The boolean flag `trans` describes whether the memory holds the transpose
+ * of the matrix. If `trans` is true, the last two dims of the matrix are
+ * transposed, i.e., the memory layout is [Width, Height] or
+ * [BatchSize, Width, Height].
+ *
+ * The MatDescriptor is not only the dimension or shape of a matrix; it also
+ * contains the layout and stride of the matrix. It is clearer to have a
+ * dedicated structure than to reuse `DDim`.
+ */
+struct MatDescriptor {
+  int64_t height_;
+  int64_t width_;
+  int64_t stride_{0};
+  int64_t batch_size_{0};
+  bool trans_;
+};
+
+/**
+ * Create a MatDescriptor from a tensor dim, num_flatten_cols, and transpose
+ * flag.
+ *
+ * @param tensor_dim: The dimension of the tensor. Its rank must be larger
+ * than 1.
+ *
+ * @param num_flatten_cols: Reshape the tensor to a matrix whose first
+ * dimension (column length) is the product of the tensor's first
+ * `num_flatten_cols` dimensions. If num_flatten_cols is zero, the first N-2
+ * dimensions form the batch_size of the descriptor instead.
+ *
+ * @param trans: True if the matrix is transposed.
+ */
+extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
+                                            int num_flatten_cols, bool trans);
+
 template <typename DeviceContext>
 class Blas {
  public:
@@ -90,6 +134,11 @@ class Blas {
                   int K, T alpha, const T* A, const T* B, T beta, T* C,
                   int batchCount, int64_t strideA, int64_t strideB) const;
 
+  template <typename T>
+  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
+              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
+              T alpha, framework::Tensor* mat_out, T beta) const;
+
  private:
  const DeviceContext& context_;
 };
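For context on how these declarations fit together, here is a hedged call-site sketch. It is not code from this commit: it assumes the GetBlas helper declared elsewhere in blas.h, a kernel ExecutionContext providing inputs "X", "Y" and an output "Out", and the name MatMulKernelSketch is illustrative only.

// Hedged call-site sketch (not part of this diff).
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
void MatMulKernelSketch(const framework::ExecutionContext& ctx) {
  auto& x = *ctx.Input<framework::Tensor>("X");
  auto& y = *ctx.Input<framework::Tensor>("Y");
  auto* out = ctx.Output<framework::Tensor>("Out");
  out->mutable_data<T>(ctx.GetPlace());

  // One descriptor per operand; num_flatten_cols = 0 keeps any leading
  // dimensions as a batch instead of folding them into the matrix height.
  auto dim_x = math::CreateMatrixDescriptor(x.dims(), 0, /*trans=*/false);
  auto dim_y = math::CreateMatrixDescriptor(y.dims(), 0, /*trans=*/false);

  // Out = 1 * X * Y + 0 * Out; MatMul picks plain or batched GEMM internally.
  auto blas = math::GetBlas<DeviceContext, T>(ctx);  // assumed helper in blas.h
  blas.MatMul(x, dim_x, y, dim_y, static_cast<T>(1), out, static_cast<T>(0));
}

}  // namespace operators
}  // namespace paddle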

paddle/fluid/operators/math/blas_impl.h

Lines changed: 25 additions & 0 deletions
@@ -180,6 +180,31 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
 #endif
 }
 
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
+                                 const MatDescriptor &dim_a,
+                                 const framework::Tensor &mat_b,
+                                 const MatDescriptor &dim_b, T alpha,
+                                 framework::Tensor *mat_out, T beta) const {
+  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+  CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
+  if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
+    this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
+                           dim_a.width_, alpha, mat_a.data<T>(),
+                           mat_b.data<T>(), beta, mat_out->data<T>());
+  } else {
+    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
+                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
+    this->template BatchedGEMM<T>(
+        transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
+        mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
+        dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
+        dim_a.stride_, dim_b.stride_);
+  }
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
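The dispatch rule above is: if neither descriptor carries a batch size, a single GEMM is issued; otherwise a batched GEMM runs with the batch count taken from whichever operand is batched, and a zero stride_ makes the un-batched operand be reused for every batch. Below is a minimal standalone sketch of just that decision (descriptor fields only, no BLAS call; the names Desc/DescribeDispatch are illustrative, not Paddle API).

// Standalone sketch of the GEMM-vs-BatchedGEMM dispatch in Blas::MatMul above.
#include <cstdint>
#include <iostream>

struct Desc {           // only the MatDescriptor fields the dispatch looks at
  int64_t height_;
  int64_t width_;
  int64_t stride_;      // 0 for an un-batched operand
  int64_t batch_size_;  // 0 for an un-batched operand
};

void DescribeDispatch(const Desc& a, const Desc& b) {
  if (a.batch_size_ == 0 && b.batch_size_ == 0) {
    std::cout << "single GEMM: " << a.height_ << "x" << a.width_ << " * "
              << b.height_ << "x" << b.width_ << "\n";
  } else {
    int64_t batch = a.batch_size_ == 0 ? b.batch_size_ : a.batch_size_;
    std::cout << "batched GEMM, batch=" << batch << ", strideA=" << a.stride_
              << ", strideB=" << b.stride_ << "\n";
  }
}

int main() {
  Desc a{3, 4, 12, 2};  // a [2, 3, 4] tensor: batched, stride 3*4
  Desc b{4, 5, 0, 0};   // a plain [4, 5] matrix: stride 0, reused per batch
  DescribeDispatch(a, b);  // prints: batched GEMM, batch=2, strideA=12, strideB=0
  return 0;
}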

paddle/fluid/operators/math/matmul.h

Lines changed: 0 additions & 149 deletions
This file was deleted.
