@@ -46,6 +46,25 @@ namespace paddle {
46
46
namespace operators {
47
47
namespace math {
48
48
49
+ /**
50
+ * Matrix Descriptor of a memory buffer.
51
+ *
52
+ * It is used for Blas::MatMul. MatMul operator can be batched.
53
+ * If Mat A is [BatchSize, H, W] and Mat B is [BatchSize, H, W], it will be a
54
+ * `batch_size` times of GEMM. The batched GEMM could be faster based on the
55
+ * implementation of the blas library. The batch size could be zero. If any
56
+ * matrix of `matmul` has a batch size, it will be a batched GEMM, too. E.g.,
57
+ * Mat A is [BatchSize, H1, W1], and Mat B is [H2, W2]. The result matrix will be
58
+ * [BatchSize, H1, W2]
59
+ *
60
+ * The boolean flag, `trans`, describes whether the memory is the transpose of the matrix or
61
+ * not. If `trans` is true, the last two dims of the matrix are transposed. The
62
+ * memory layout of the matrix is [Width, Height] or [BatchSize, Width, Height].
63
+ *
64
+ * The MatDescriptor is not only the dimension or shape of a matrix, it also
65
+ * contains the layout and stride of the matrix. It is clearer to have a structure than
66
+ * reuse `DDim`.
67
+ */
49
68
struct MatDescriptor {
50
69
int64_t height_;
51
70
int64_t width_;
@@ -54,8 +73,22 @@ struct MatDescriptor {
54
73
bool trans_;
55
74
};
56
75
57
- extern MatDescriptor GetMatDim (const framework::DDim& tensor,
58
- int num_flatten_cols, bool trans);
76
+ /**
77
+ * Create Matrix Descriptor from a tensor dim, num_flatten_cols, and transpose
78
+ * flag
79
+ *
80
+ * @param tensor_dim: The dimension of the tensor. The rank of this dimension
81
+ * must be larger than 1.
82
+ *
83
+ * @param num_flatten_cols: Reshape a tensor to a matrix. The matrix's first
84
+ * dimension (column length) will be the product of the tensor's first `num_flatten_cols`
85
+ * dimensions. If num_flatten_cols is zero, the first N-2 dimensions will be the
86
+ * batch_size of descriptor.
87
+ *
88
+ * @param trans: True if the matrix is transposed.
89
+ */
90
+ extern MatDescriptor CreateMatrixDescriptor (const framework::DDim& tensor_dim,
91
+ int num_flatten_cols, bool trans);
59
92
60
93
template <typename DeviceContext>
61
94
class Blas {
0 commit comments