Skip to content

Commit 0a13d3c (parent: 3dd0182)

Browse files

Move MatMul to blas_impl.h
Rename MatDim to MatDescriptor

File tree

4 files changed: 35 additions (+), 26 deletions (−)

paddle/fluid/operators/math/blas.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
namespace paddle {
1919
namespace operators {
2020
namespace math {
21-
MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) {
22-
MatDim retv;
21+
MatDescriptor GetMatDim(const framework::DDim& dim, int num_flatten_cols,
22+
bool trans) {
23+
MatDescriptor retv;
2324
if (num_flatten_cols > 1) {
2425
auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
2526
retv.height_ = flatten_dim[0];

paddle/fluid/operators/math/blas.h

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ namespace paddle {
4646
namespace operators {
4747
namespace math {
4848

49-
struct MatDim {
49+
struct MatDescriptor {
5050
int64_t height_;
5151
int64_t width_;
5252
int64_t stride_{0};
5353
int64_t batch_size_{0};
5454
bool trans_;
5555
};
5656

57-
extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
58-
bool trans);
57+
extern MatDescriptor GetMatDim(const framework::DDim& tensor,
58+
int num_flatten_cols, bool trans);
5959

6060
template <typename DeviceContext>
6161
class Blas {
@@ -102,26 +102,9 @@ class Blas {
102102
int batchCount, int64_t strideA, int64_t strideB) const;
103103

104104
template <typename T>
105-
void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
106-
const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
107-
framework::Tensor* mat_out, T beta) const {
108-
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
109-
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
110-
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
111-
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
112-
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
113-
dim_a.width_, alpha, mat_a.data<T>(),
114-
mat_b.data<T>(), beta, mat_out->data<T>());
115-
} else {
116-
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
117-
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
118-
this->template BatchedGEMM<T>(
119-
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
120-
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
121-
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
122-
dim_a.stride_, dim_b.stride_);
123-
}
124-
}
105+
void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
106+
const framework::Tensor& mat_b, const MatDescriptor& dim_b,
107+
T alpha, framework::Tensor* mat_out, T beta) const;
125108

126109
private:
127110
const DeviceContext& context_;

paddle/fluid/operators/math/blas_impl.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,31 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
180180
#endif
181181
}
182182

183+
template <typename DeviceContext>
184+
template <typename T>
185+
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
186+
const MatDescriptor &dim_a,
187+
const framework::Tensor &mat_b,
188+
const MatDescriptor &dim_b, T alpha,
189+
framework::Tensor *mat_out, T beta) const {
190+
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
191+
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
192+
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
193+
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
194+
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
195+
dim_a.width_, alpha, mat_a.data<T>(),
196+
mat_b.data<T>(), beta, mat_out->data<T>());
197+
} else {
198+
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
199+
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
200+
this->template BatchedGEMM<T>(
201+
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
202+
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
203+
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
204+
dim_a.stride_, dim_b.stride_);
205+
}
206+
}
207+
183208
} // namespace math
184209
} // namespace operators
185210
} // namespace paddle

paddle/fluid/operators/matmul_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ inline framework::Tensor CombineBatchAndN(const DeviceContext& context,
9191
}
9292

9393
inline void NormalizeTensorShape(framework::Tensor* x,
94-
const math::MatDim& mat_dim_x) {
94+
const math::MatDescriptor& mat_dim_x) {
9595
int64_t h, w;
9696
h = mat_dim_x.height_;
9797
w = mat_dim_x.width_;

0 commit comments

Comments
 (0)