@@ -16,11 +16,40 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/math_function_impl.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
 
+using float16 = paddle::platform::float16;
+
+template <>
+void gemm<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
+    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const float16* B, const float16 beta,
+    float16* C) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      h_A, lda, &h_beta, h_C, N));
+}
+
 template <>
 void gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
@@ -60,6 +89,28 @@ void gemm<platform::CUDADeviceContext, double>(
       lda, &beta, C, N));
 }
 
+template <>
+void gemm<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context, const bool transA,
+    const bool transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const int lda, const float16* B,
+    const int ldb, const float16 beta, float16* C, const int ldc) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemm(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      h_A, lda, &h_beta, h_C, ldc));
+}
+
 template <>
 void gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const bool transA,
@@ -90,6 +141,35 @@ void gemm<platform::CUDADeviceContext, double>(
       lda, &beta, C, ldc));
 }
 
+template <>
+void matmul<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context,
+    const framework::Tensor& matrix_a, bool trans_a,
+    const framework::Tensor& matrix_b, bool trans_b, float16 alpha,
+    framework::Tensor* matrix_out, float16 beta) {
+  auto dim_a = matrix_a.dims();
+  auto dim_b = matrix_b.dims();
+  auto dim_out = matrix_out->dims();
+  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
+                 "The input and output of matmul must be matrices");
+
+  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
+                     platform::is_gpu_place(matrix_b.place()) &&
+                     platform::is_gpu_place(matrix_out->place()),
+                 "Matrix must all be in CUDAPlace");
+
+  int M = dim_out[0];
+  int N = dim_out[1];
+  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
+
+  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::CUDADeviceContext, float16>(
+      context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(),
+      matrix_b.data<float16>(), beta, matrix_out->data<float16>());
+}
+
 template <>
 void matmul<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context,
@@ -148,6 +228,34 @@ void matmul<platform::CUDADeviceContext, double>(
       matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
 
+template <>
+void batched_gemm<platform::CUDADeviceContext, float16>(
+    const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
+    const CBLAS_TRANSPOSE transB, const int M, const int N, const int K,
+    const float16 alpha, const float16* A, const float16* B, const float16 beta,
+    float16* C, const int batchCount, const int strideA, const int strideB) {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  int ldc = N;
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  const int strideC = M * N;
+
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half* h_A = reinterpret_cast<const half*>(A);
+  const half* h_B = reinterpret_cast<const half*>(B);
+  half* h_C = reinterpret_cast<half*>(C);
+
+  PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
+      context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+      strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
+}
+
 template <>
 void batched_gemm<platform::CUDADeviceContext, float>(
     const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA,
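Below is a minimal usage sketch (not part of the diff) showing how the new float16 matmul specialization might be exercised from a GPU-side test. Only the matmul/gemm signatures come from the change above; the tensor and device-context helpers used here (CUDAPlace, the CUDADeviceContext constructor, mutable_data, make_ddim, Wait) are assumptions based on the surrounding fluid framework.

```cpp
// Hypothetical sketch, not part of this commit: drive the float16 matmul
// specialization (which dispatches to cublasHgemm) on GPU tensors.
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/float16.h"

namespace f = paddle::framework;
namespace p = paddle::platform;
namespace math = paddle::operators::math;
using p::float16;

void MatMulFp16Sketch() {
  p::CUDAPlace gpu(0);
  p::CUDADeviceContext context(gpu);

  // Allocate a 2x3 and a 3x2 input plus a 2x2 output directly on the GPU
  // (values left uninitialized for brevity; a real test would copy data in).
  f::Tensor a, b, out;
  a.mutable_data<float16>(f::make_ddim({2, 3}), gpu);
  b.mutable_data<float16>(f::make_ddim({3, 2}), gpu);
  out.mutable_data<float16>(f::make_ddim({2, 2}), gpu);

  // out = 1.0 * a * b + 0.0 * out, using the float16 specialization added
  // in this change.
  math::matmul<p::CUDADeviceContext, float16>(
      context, a, false, b, false, float16(1.0f), &out, float16(0.0f));
  context.Wait();
}
```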