Skip to content

Commit fcd31d6

Browse files
committed
Follow comments and polish code names
1 parent 0a13d3c commit fcd31d6

File tree

5 files changed

+323
-292
lines changed

5 files changed

+323
-292
lines changed

paddle/fluid/operators/math/blas.cc

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,26 @@
1818
namespace paddle {
1919
namespace operators {
2020
namespace math {
21-
MatDescriptor GetMatDim(const framework::DDim& dim, int num_flatten_cols,
22-
bool trans) {
21+
MatDescriptor CreateMatrixDescriptor(const framework::DDim &tensor_dim,
22+
int num_flatten_cols, bool trans) {
23+
PADDLE_ENFORCE_GT(tensor_dim.size(), 1);
2324
MatDescriptor retv;
2425
if (num_flatten_cols > 1) {
25-
auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
26+
auto flatten_dim = framework::flatten_to_2d(tensor_dim, num_flatten_cols);
2627
retv.height_ = flatten_dim[0];
2728
retv.width_ = flatten_dim[1];
2829
} else {
29-
if (dim.size() == 1) {
30-
retv.height_ = 1;
31-
retv.width_ = dim[0];
32-
} else if (dim.size() == 2) {
33-
retv.height_ = dim[0];
34-
retv.width_ = dim[1];
30+
if (tensor_dim.size() == 2) {
31+
retv.height_ = tensor_dim[0];
32+
retv.width_ = tensor_dim[1];
3533
} else {
36-
if (dim.size() == 3) {
37-
retv.batch_size_ = dim[0];
38-
retv.height_ = dim[1];
39-
retv.width_ = dim[2];
40-
} else {
41-
auto dim_vec = framework::vectorize(dim);
42-
retv.batch_size_ = 1;
43-
for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
44-
retv.batch_size_ *= dim_vec[i];
45-
retv.height_ = dim_vec[dim_vec.size() - 2];
46-
retv.width_ = dim_vec[dim_vec.size() - 1];
47-
}
34+
auto dim_vec = framework::vectorize(tensor_dim);
35+
retv.batch_size_ = 1;
36+
for (size_t i = 0; i < dim_vec.size() - 2; ++i) {
37+
retv.batch_size_ *= dim_vec[i];
4838
}
39+
retv.height_ = dim_vec[dim_vec.size() - 2];
40+
retv.width_ = dim_vec[dim_vec.size() - 1];
4941
retv.stride_ = retv.height_ * retv.width_;
5042
}
5143
}

paddle/fluid/operators/math/blas.h

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,25 @@ namespace paddle {
4646
namespace operators {
4747
namespace math {
4848

49+
/**
50+
* Matrix Descriptor of a memory buffer.
51+
*
52+
* It is used for Blas::MatMul. MatMul operator can be batched.
53+
* if Mat A is [BatchSize, H, W], Mat B is [BatchSize, H, W]. It will be a
54+
* `batch_size` times of GEMM. The batched GEMM could be faster based on the
55+
* implementation of the blas library. The batch size could be zero. If any
56+
* matrix of `matmul` has a batch size, there will be a batched GEMM, too. e.g.,
57+
* Mat A is [BatchSize, H1, W1], and Mat B is [H2, W2], the result matrix will be
58+
* [BatchSize, H1, W2]
59+
*
60+
* The boolean flag, `trans`, describes whether the memory is the transpose of the matrix or
61+
* not. If the trans is true, the last two dims of matrix are transposed. The
62+
* memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height].
63+
*
64+
* The MatDescriptor is not only the dimension or shape of a matrix, it also
65+
* contains the layout, stride of matrix. It is clearer to have a structure than
66+
* reuse `DDim`.
67+
*/
4968
struct MatDescriptor {
5069
int64_t height_;
5170
int64_t width_;
@@ -54,8 +73,22 @@ struct MatDescriptor {
5473
bool trans_;
5574
};
5675

57-
extern MatDescriptor GetMatDim(const framework::DDim& tensor,
58-
int num_flatten_cols, bool trans);
76+
/**
77+
* Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose
78+
* flag
79+
*
80+
* @param tensor_dim: The dimension of the tensor. The rank of this dimension
81+
* must larger than 1.
82+
*
83+
* @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first
84+
* dimension(column length) will be the product of tensor's first `num_col_dims`
85+
* dimensions. If num_flatten_cols is zero, the first N-2 dimension will be the
86+
* batch_size of descriptor.
87+
*
88+
* @param trans: True if the matrix is transposed.
89+
*/
90+
extern MatDescriptor CreateMatrixDescriptor(const framework::DDim& tensor_dim,
91+
int num_flatten_cols, bool trans);
5992

6093
template <typename DeviceContext>
6194
class Blas {

0 commit comments

Comments
 (0)