
Commit 90215b7

Add float16 GEMM math function on GPU (#8695)
* test cpu float16 data transform
* add isnan etc
* small fix
* fix containsNAN test error
* add data_type transform GPU test
* add float16 GPU example
* fix error
* fix GPU test error
* initial commit
* fix error
* small fix
* add more gemm fp16 tests
* fix error
* add utility function
1 parent 8e024d3 commit 90215b7

File tree: 4 files changed, +449 -97 lines changed

paddle/fluid/operators/math/math_function.cc

Lines changed: 39 additions & 0 deletions
@@ -15,11 +15,23 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
+#include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {
 namespace math {

+using float16 = paddle::platform::float16;
+
+template <>
+void gemm<platform::CPUDeviceContext, float16>(
+    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
+    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const float16* B, const float16 beta,
+    float16* C) {
+  PADDLE_THROW("float16 GEMM not supported on CPU");
+}
+
 template <>
 void gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
@@ -46,6 +58,15 @@ void gemm<platform::CPUDeviceContext, double>(
               beta, C, ldc);
 }

+template <>
+void gemm<platform::CPUDeviceContext, float16>(
+    const platform::CPUDeviceContext& context, const bool transA,
+    const bool transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const int lda, const float16* B,
+    const int ldb, const float16 beta, float16* C, const int ldc) {
+  PADDLE_THROW("float16 GEMM not supported on CPU");
+}
+
 template <>
 void gemm<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context, const bool transA,
@@ -68,6 +89,15 @@ void gemm<platform::CPUDeviceContext, double>(
               lda, B, ldb, beta, C, ldc);
 }

+template <>
+void matmul<platform::CPUDeviceContext, float16>(
+    const platform::CPUDeviceContext& context,
+    const framework::Tensor& matrix_a, bool trans_a,
+    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
+    framework::Tensor* matrix_out, float16 beta) {
+  PADDLE_THROW("float16 matmul not supported on CPU");
+}
+
 template <>
 void matmul<platform::CPUDeviceContext, float>(
     const platform::CPUDeviceContext& context,
@@ -126,6 +156,15 @@ void matmul<platform::CPUDeviceContext, double>(
       matrix_b.data<double>(), beta, matrix_out->data<double>());
 }

+template <>
+void batched_gemm<platform::CPUDeviceContext, float16>(
+    const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA,
+    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const float16* B, const float16 beta,
+    float16* C, const int batchCount, const int strideA, const int strideB) {
+  PADDLE_THROW("float16 batched_gemm not supported on CPU");
+}
+
 #ifdef PADDLE_WITH_MKLML
 // Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize.
 template <>

paddle/fluid/operators/math/math_function.cu

Lines changed: 108 additions & 0 deletions
@@ -16,11 +16,40 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
+#include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {
 namespace math {

+using float16 = paddle::platform::float16;
+
+template <>
+void gemm<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
+    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const float16* B, const float16 beta,
+    float16* C) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      h_A, lda, &h_beta, h_C, N));
+}
+
 template <>
 void gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
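The specialization above produces a row-major C = alpha * A * B + beta * C from column-major cuBLAS by passing the operands in swapped order (B first, then A) with the M and N arguments exchanged, which is also why N is used as C's leading dimension. The following standalone sketch shows the same trick with raw cuBLAS outside of Paddle; buffer names and sizes are purely illustrative, and it assumes a toolkit where cuda_fp16.h exposes __float2half on the host (roughly CUDA 9.0+) and cuBLAS provides cublasHgemm (CUDA 7.5+).

// row_major_hgemm.cu -- minimal sketch, not Paddle code.
// Computes row-major C(MxN) = A(MxK) * B(KxN) in fp16 by calling the
// column-major cublasHgemm with the operands swapped (B first, then A).
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int M = 2, N = 3, K = 4;
  __half *d_A, *d_B, *d_C;
  cudaMalloc(&d_A, M * K * sizeof(__half));
  cudaMalloc(&d_B, K * N * sizeof(__half));
  cudaMalloc(&d_C, M * N * sizeof(__half));
  // (fill d_A and d_B with real data before the call in actual use)

  cublasHandle_t handle;
  cublasCreate(&handle);

  const __half alpha = __float2half(1.0f);
  const __half beta = __float2half(0.0f);
  // Row-major A*B equals column-major B*A, so pass B as the first operand
  // and exchange the M/N arguments; C's leading dimension is then N.
  cublasStatus_t stat =
      cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                  d_B, N, d_A, K, &beta, d_C, N);
  printf("cublasHgemm status: %d\n", static_cast<int>(stat));

  cublasDestroy(handle);
  cudaFree(d_A);
  cudaFree(d_B);
  cudaFree(d_C);
  return 0;
}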
@@ -60,6 +89,28 @@ void gemm<platform::CUDADeviceContext, double>(
               lda, &beta, C, N));
 }

+template <>
+void gemm<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context, const bool transA,
+    const bool transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const int lda, const float16* B,
+    const int ldb, const float16 beta, float16* C, const int ldc) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      h_A, lda, &h_beta, h_C, ldc));
+}
+
 template <>
 void gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const bool transA,
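Both gemm overloads convert alpha and beta with static_cast and reinterpret the float16 pointers as half pointers; this is only sound because paddle::platform::float16 stores the same 16-bit IEEE 754 pattern as CUDA's half. The tiny sketch below illustrates that layout assumption with a stand-in float16 struct rather than the real Paddle type, and it assumes a toolkit where cuda_fp16.h offers host-side __half2float.

// fp16_layout_check.cu -- minimal sketch, not Paddle code.
// Mirrors the reinterpret_cast used above; it relies on float16 and __half
// having identical size and bit layout.
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

struct float16 {   // stand-in for paddle::platform::float16
  uint16_t x;      // raw IEEE 754 half-precision bits
};

int main() {
  static_assert(sizeof(float16) == sizeof(__half),
                "float16 must be bit-compatible with __half");
  float16 v;
  v.x = 0x3C00;  // bit pattern of 1.0 in IEEE fp16
  const __half* h = reinterpret_cast<const __half*>(&v);
  std::printf("%f\n", __half2float(*h));  // prints 1.000000
  return 0;
}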
@@ -90,6 +141,35 @@ void gemm<platform::CUDADeviceContext, double>(
               lda, &beta, C, ldc));
 }

+template <>
+void matmul<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context,
+    const framework::Tensor& matrix_a, bool trans_a,
+    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
+    framework::Tensor* matrix_out, float16 beta) {
+  auto dim_a = matrix_a.dims();
+  auto dim_b = matrix_b.dims();
+  auto dim_out = matrix_out->dims();
+  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
+                 "The input and output of matmul be matrix");
+
+  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
+                     platform::is_gpu_place(matrix_b.place()) &&
+                     platform::is_gpu_place(matrix_out->place()),
+                 "Matrix must all be in CUDAPlace");
+
+  int M = dim_out[0];
+  int N = dim_out[1];
+  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
+
+  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::CUDADeviceContext, float16>(
+      context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
+      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
+}
+
 template <>
 void matmul<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context,
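matmul itself does no cuBLAS work: it checks that all three tensors are 2-D and resident on the GPU, derives M and N from the output shape and K from the (possibly transposed) first input, and forwards to the gemm specialization above. A minimal sketch of just that shape logic, using plain arrays instead of framework::Tensor, is shown below; the concrete shapes are illustrative.

// matmul_shape_logic.cc -- minimal sketch, not Paddle code.
// Reproduces how matmul derives the gemm arguments from tensor shapes.
#include <array>
#include <cstdio>

int main() {
  std::array<int, 2> dim_a = {2, 4};    // matrix_a is 2x4
  std::array<int, 2> dim_b = {4, 3};    // matrix_b is 4x3
  std::array<int, 2> dim_out = {2, 3};  // matrix_out is 2x3
  bool trans_a = false, trans_b = false;

  int M = dim_out[0];                    // rows of the output
  int N = dim_out[1];                    // columns of the output
  int K = !trans_a ? dim_a[1] : dim_a[0];  // shared inner dimension

  // These are the values handed to gemm<CUDADeviceContext, float16>.
  std::printf("M=%d N=%d K=%d transA=%c transB=%c\n", M, N, K,
              trans_a ? 'T' : 'N', trans_b ? 'T' : 'N');
  return 0;
}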
@@ -148,6 +228,34 @@ void matmul<platform::CUDADeviceContext, double>(
       matrix_b.data<double>(), beta, matrix_out->data<double>());
 }

+template <>
+void batched_gemm<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
+    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const float16* B, const float16 beta,
+    float16* C, const int batchCount, const int strideA, const int strideB) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  int ldc = N;
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  const int strideC = M * N;
+
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
+}
+
 template <>
 void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
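batched_gemm applies the same operand swap per batch element and lets cuBLAS step through the batch with fixed strides; strideC is M * N because each output matrix is written densely. The standalone sketch below shows an equivalent call to cublasHgemmStridedBatched outside Paddle; names and sizes are illustrative, and it assumes CUDA 8.0 or later, where the strided-batched half-precision routine and host-side __float2half are available.

// strided_batched_hgemm.cu -- minimal sketch, not Paddle code.
// Batched fp16 GEMM over `batch` independent MxK * KxN products laid out
// contiguously, using the same row-major-to-column-major operand swap.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int batch = 8, M = 2, N = 3, K = 4;
  const long long strideA = static_cast<long long>(M) * K;
  const long long strideB = static_cast<long long>(K) * N;
  const long long strideC = static_cast<long long>(M) * N;

  __half *d_A, *d_B, *d_C;
  cudaMalloc(&d_A, batch * strideA * sizeof(__half));
  cudaMalloc(&d_B, batch * strideB * sizeof(__half));
  cudaMalloc(&d_C, batch * strideC * sizeof(__half));
  // (fill d_A and d_B with real data before the call in actual use)

  cublasHandle_t handle;
  cublasCreate(&handle);
  const __half alpha = __float2half(1.0f);
  const __half beta = __float2half(0.0f);

  // As in the single-matrix case, B is passed first so each row-major
  // product comes out correctly from the column-major routine.
  cublasStatus_t stat = cublasHgemmStridedBatched(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, d_B, N, strideB,
      d_A, K, strideA, &beta, d_C, N, strideC, batch);
  printf("cublasHgemmStridedBatched status: %d\n", static_cast<int>(stat));

  cublasDestroy(handle);
  cudaFree(d_A);
  cudaFree(d_B);
  cudaFree(d_C);
  return 0;
}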
