Commit 2924c92

Merge pull request #10569 from reyoung/feature/matmul_support_float16_double
matmul support float16/double
2 parents: 5ce2df9 + 05a96db

4 files changed (+66, -41 lines)
paddle/fluid/operators/math/blas_impl.cu.h (15 additions, 3 deletions)

```diff
@@ -96,10 +96,22 @@ struct CUBlas<platform::float16> {
         reinterpret_cast<__half *>(C), ldc));
   }
 
-  template <typename... ARGS>
-  static void GEMM_BATCH(ARGS... args) {
+  static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa,
+                         cublasOperation_t transb, int m, int n, int k,
+                         const float16 *alpha, const float16 *A, int lda,
+                         long long int strideA, const float16 *B,  // NOLINT
+                         int ldb, long long int strideB,           // NOLINT
+                         const float16 *beta, float16 *C, int ldc,
+                         long long int strideC,                    // NOLINT
+                         int batchCount) {
 #if CUDA_VERSION >= 8000
-    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...));
+    PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+        handle, transa, transb, m, n, k,
+        reinterpret_cast<const __half *>(alpha),
+        reinterpret_cast<const __half *>(A), lda, strideA,
+        reinterpret_cast<const __half *>(B), ldb, strideB,
+        reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
+        ldc, strideC, batchCount));
 #else
     PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5");
 #endif
```
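
The old wrapper forwarded arbitrary arguments straight through, but cuBLAS's `cublasHgemmStridedBatched` takes `__half *` pointers while callers hold `platform::float16 *`. Spelling out the signature lets the wrapper perform all the `reinterpret_cast`s in one place. Below is a minimal standalone sketch (not from the PR) of calling `cublasHgemmStridedBatched` directly with the same cast-free `__half` layout; sizes and leading dimensions are illustrative, and it assumes a toolkit recent enough that `__float2half` is host-callable:

```cpp
// Sketch: direct strided-batched half-precision GEMM (CUDA >= 8.0).
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
  const int m = 4, n = 4, k = 4, batch = 8;
  const long long strideA = static_cast<long long>(m) * k;
  const long long strideB = static_cast<long long>(k) * n;
  const long long strideC = static_cast<long long>(m) * n;

  __half *A, *B, *C;
  cudaMalloc(&A, batch * strideA * sizeof(__half));
  cudaMalloc(&B, batch * strideB * sizeof(__half));
  cudaMalloc(&C, batch * strideC * sizeof(__half));

  cublasHandle_t handle;
  cublasCreate(&handle);

  // alpha/beta are read from the host (default CUBLAS_POINTER_MODE_HOST).
  const __half alpha = __float2half(1.0f);
  const __half beta = __float2half(0.0f);

  // Column-major GEMM over `batch` matrices laid out back to back.
  cublasHgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                            &alpha, A, /*lda=*/m, strideA,
                            B, /*ldb=*/k, strideB,
                            &beta, C, /*ldc=*/m, strideC, batch);

  cublasDestroy(handle);
  cudaFree(A);
  cudaFree(B);
  cudaFree(C);
  return 0;
}
```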

paddle/fluid/operators/math/blas_impl.h (3 additions, 3 deletions)

```diff
@@ -172,9 +172,9 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
       c_array.data(), &ldc, 1 /* group_count */, &batchCount);
 #else
   for (int k = 0; k < batchCount; ++k) {
-    const float *Ak = &A[k * strideA];
-    const float *Bk = &B[k * strideB];
-    float *Ck = &C[k * M * N];
+    auto *Ak = &A[k * strideA];
+    auto *Bk = &B[k * strideB];
+    auto *Ck = &C[k * M * N];
     this->template GEMM<T>(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck);
   }
 #endif
```
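
`BatchedGEMM` is a member template over the element type `T`, but the non-MKL fallback loop had hard-coded `float` pointers, which breaks as soon as the loop is instantiated for `double` or `float16`. With `auto`, the sub-matrix pointers are deduced from the arguments instead. A self-contained sketch of the same pattern, with a naive reference GEMM standing in for the real BLAS dispatch:

```cpp
#include <cstdint>

// Row-major, non-transposed reference GEMM: C = alpha * A * B + beta * C.
// Used only to make the sketch self-contained.
template <typename T>
void NaiveGemm(int M, int N, int K, T alpha, const T *A, const T *B, T beta,
               T *C) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      T acc = T(0);
      for (int p = 0; p < K; ++p) acc += A[i * K + p] * B[p * N + j];
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}

template <typename T>
void BatchedGemmFallback(int M, int N, int K, T alpha, const T *A, const T *B,
                         T beta, T *C, int batchCount, int64_t strideA,
                         int64_t strideB) {
  for (int b = 0; b < batchCount; ++b) {
    // `auto` deduces const T * / T * from the arguments, so the same loop
    // serves float, double, or a half type with arithmetic operators
    // (as Paddle's float16 provides).
    auto *Ab = &A[b * strideA];
    auto *Bb = &B[b * strideB];
    auto *Cb = &C[b * M * N];
    NaiveGemm(M, N, K, alpha, Ab, Bb, beta, Cb);
  }
}
```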

paddle/fluid/operators/math/math_function.cu (4 additions, 3 deletions)

```diff
@@ -33,9 +33,10 @@ template struct SetConstant<platform::CUDADeviceContext, int>;
 template struct SetConstant<platform::CUDADeviceContext, int64_t>;
 template struct SetConstant<platform::CUDADeviceContext, bool>;
 
-#define DEFINE_GPU_TRANS(RANK)                                         \
-  template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
-  template struct Transpose<platform::CUDADeviceContext, double, RANK>;
+#define DEFINE_GPU_TRANS(RANK)                                          \
+  template struct Transpose<platform::CUDADeviceContext, float, RANK>;  \
+  template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
+  template struct Transpose<platform::CUDADeviceContext, float16, RANK>;
 
 DEFINE_GPU_TRANS(1);
 DEFINE_GPU_TRANS(2);
```
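
`Transpose` is defined in this `.cu` translation unit and explicitly instantiated per (type, rank) pair; any type missing from the macro fails at link time when an op tries to transpose tensors of that type, which is why `float16` must be added alongside `float` and `double`. A compilable sketch of the pattern (the `float16` stub and the empty kernel body are placeholders, not Paddle's definitions):

```cpp
#include <cstdint>

// Hypothetical stub standing in for paddle::platform::float16.
struct float16 {
  uint16_t bits;
};

// The template body is visible only in this translation unit; other files
// see a declaration and link against the instantiations emitted here.
template <typename T, int Rank>
struct Transpose {
  void Run(const T * /*in*/, T * /*out*/) { /* kernel launch elided */ }
};

// One explicit instantiation per (type, rank) pair the rest of the codebase
// may request; omitting a type causes undefined-symbol link errors.
#define DEFINE_GPU_TRANS(RANK)              \
  template struct Transpose<float, RANK>;   \
  template struct Transpose<double, RANK>;  \
  template struct Transpose<float16, RANK>;

DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
```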

paddle/fluid/operators/matmul_op.cc (44 additions, 32 deletions)

```diff
@@ -25,7 +25,7 @@ namespace operators {
  * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
  * original x_dim is returned.
  */
-static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
+static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
   if (x_dim.size() > 1) {
     return x_dim;
   }
@@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) {
  * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
  * original y_dim is returned.
  */
-static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
+static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) {
   if (y_dim.size() > 1) {
     return y_dim;
   }
@@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) {
 template <typename DeviceContext, typename T>
 class MatMulKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto& x =
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto &x =
         detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
-    auto& y =
+    auto &y =
         detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
-    auto* out = context.Output<framework::Tensor>("Out");
+    auto *out = context.Output<framework::Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
 
     auto blas = math::GetBlas<DeviceContext, T>(context);
@@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel<T> {
 
 // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
 // Identity op if the tensor is not of rank 3.
-static framework::Tensor FoldInitDims(const framework::Tensor& input) {
+static framework::Tensor FoldInitDims(const framework::Tensor &input) {
   auto output = input;
   auto in_dims = input.dims();
   if (in_dims.size() == 3) {
@@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
 // (Warning: This requires transposing data and writes into new memory.)
 // Identity op if the tensor is not of rank 3.
 template <typename DeviceContext, typename T>
-static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
-                                             const framework::Tensor& input) {
+static framework::Tensor FoldHeadAndLastDims(const DeviceContext &context,
+                                             const framework::Tensor &input) {
   auto in_dims = input.dims();
   if (in_dims.size() != 3) {
     return input;
@@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
  * If transposed, `H,W` will be swapped.
  */
 static void ReshapeTensorIntoMatrixSequence(
-    framework::Tensor* x, const math::MatDescriptor& descriptor) {
+    framework::Tensor *x, const math::MatDescriptor &descriptor) {
   int64_t h, w;
   h = descriptor.height_;
   w = descriptor.width_;
@@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence(
  * If any of `X` and `Y` has batch size BatchSize, the out will have the
  * BatchSize.
  */
-static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
-                                           framework::Tensor* y,
-                                           framework::Tensor* out, bool trans_x,
+static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
+                                           framework::Tensor *y,
+                                           framework::Tensor *out, bool trans_x,
                                            bool trans_y) {
   auto x_dim = RowMatrixFromVector(x->dims());
   auto y_dim = ColumnMatrixFromVector(y->dims());
@@ -177,29 +177,29 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
 template <typename DeviceContext, typename T>
 class MatMulGradKernel : public framework::OpKernel<T> {
  public:
-  void MatMul(const framework::ExecutionContext& context,
-              const framework::Tensor& a, bool trans_a,
-              const framework::Tensor& b, bool trans_b,
-              framework::Tensor* out) const {
+  void MatMul(const framework::ExecutionContext &context,
+              const framework::Tensor &a, bool trans_a,
+              const framework::Tensor &b, bool trans_b,
+              framework::Tensor *out) const {
     out->mutable_data<T>(context.GetPlace());
     auto blas = math::GetBlas<DeviceContext, T>(context);
     auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
     auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
     blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0));
   }
 
-  void CalcInputGrad(const framework::ExecutionContext& context,
-                     const framework::Tensor& a, bool trans_a,
-                     bool is_fold_init_dims_a, const framework::Tensor& b,
+  void CalcInputGrad(const framework::ExecutionContext &context,
+                     const framework::Tensor &a, bool trans_a,
+                     bool is_fold_init_dims_a, const framework::Tensor &b,
                      bool trans_b, bool is_fold_init_dims_b,
-                     framework::Tensor* out) const {
+                     framework::Tensor *out) const {
     if (out == nullptr) return;
     bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
                         out->dims().size() == 2;
     if (!need_combine) {
       MatMul(context, a, trans_a, b, trans_b, out);
     } else {
-      auto& ctx = context.template device_context<DeviceContext>();
+      auto &ctx = context.template device_context<DeviceContext>();
       MatMul(context, is_fold_init_dims_a
                           ? FoldInitDims(a)
                           : FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
@@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel<T> {
     }
   }
 
-  void Compute(const framework::ExecutionContext& context) const override {
+  void Compute(const framework::ExecutionContext &context) const override {
     auto x = *context.Input<framework::Tensor>("X");
     auto y = *context.Input<framework::Tensor>("Y");
     auto dout =
         *context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
+    auto *dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto *dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
     bool transpose_x = context.Attr<bool>("transpose_X");
     bool transpose_y = context.Attr<bool>("transpose_Y");
 
@@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* context) const override {
+  void InferShape(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInput("X"),
                    "Input(X) of MatMulOp should not be null.");
     PADDLE_ENFORCE(context->HasInput("Y"),
@@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext* context) const override {
+  void InferShape(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null");
     PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null");
     PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
@@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker {
 
  protected:
   std::unique_ptr<framework::OpDesc> Apply() const override {
-    auto* retv = new framework::OpDesc();
+    auto *retv = new framework::OpDesc();
     retv->SetType("matmul_grad");
     retv->SetInput("X", Input("X"));
     retv->SetInput("Y", Input("Y"));
@@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
                   ops::MatMulOpGradMaker);
 REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
+    matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::MatMulKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16>);
 REGISTER_OP_CPU_KERNEL(
     matmul_grad,
-    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
+                          paddle::platform::float16>);
 
 #ifdef PADDLE_WITH_CUDA
 REGISTER_OP_CUDA_KERNEL(
-    matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
+    matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MatMulKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MatMulKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     matmul_grad,
-    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
+                          paddle::platform::float16>);
 #endif
```
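
Most of this file is mechanical style normalization (`T& x` to `T &x`, `T* p` to `T *p`); the functional change is at the bottom, where the `REGISTER_OP_*_KERNEL` macros now list `double` and `paddle::platform::float16` kernels in addition to `float`. The macros take a variadic list of kernel types and register each under the same op name, keyed by data type. A hypothetical, much-simplified sketch of that fan-out (names invented; the real registry also keys on place and layout, and the real macro expands recursively rather than using a C++17 fold):

```cpp
#include <map>
#include <string>
#include <typeindex>
#include <vector>

// Hypothetical registry: op name -> one entry per registered kernel type.
std::map<std::string, std::vector<std::type_index>> g_kernel_registry;

// Register every kernel type in the pack under a single op name.
template <typename... Kernels>
void RegisterOpKernels(const std::string &op_name) {
  (g_kernel_registry[op_name].emplace_back(typeid(Kernels)), ...);
}

// Stand-ins for the real kernel classes.
template <typename T>
struct MatMulKernelCPU {};

int main() {
  RegisterOpKernels<MatMulKernelCPU<float>, MatMulKernelCPU<double>>("matmul");
  // The runtime would later pick the entry matching the input tensor's dtype.
  return g_kernel_registry["matmul"].size() == 2 ? 0 : 1;
}
```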
