Skip to content

Commit 177c135

Browse files
authored
Use cublasHgemm "back" for fp16 computation with Volta GPU (microsoft#3765)
* Use cublasHgemm for fp16 computation with Volta GPU
1 parent 3421ec1 commit 177c135

File tree

4 files changed

+31
-21
lines changed

4 files changed

+31
-21
lines changed

onnxruntime/contrib_ops/cuda/bert/attention.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,18 +71,19 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
7171

7272
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
7373
// TODO: use custom kernel of expand to improve the performance.
74+
auto& device_prop = GetDeviceProp();
7475
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
7576
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one,
7677
reinterpret_cast<const CudaT*>(bias->template Data<T>()), n,
7778
GetConstOnes<CudaT>(m), 1,
78-
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n));
79+
&zero, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
7980

8081
// Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x B.
8182
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
8283
cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one,
8384
reinterpret_cast<const CudaT*>(weights->template Data<T>()), n,
8485
reinterpret_cast<const CudaT*>(input->template Data<T>()), k,
85-
&one, reinterpret_cast<CudaT*>(gemm_buffer.get()), n));
86+
&one, reinterpret_cast<CudaT*>(gemm_buffer.get()), n, device_prop));
8687

8788
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length);
8889
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);

onnxruntime/core/providers/cuda/math/gemm.cc

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -65,12 +65,11 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
6565

6666
CudaT one = ToCudaType<T>::FromFloat(1.0f);
6767
CudaT zero = ToCudaType<T>::FromFloat(0.0f);
68-
68+
auto& device_prop = GetDeviceProp();
6969
// broadcast bias if needed and is present
7070
if (beta_ != 0 && B != nullptr) {
7171
auto& b_shape = B->Shape();
7272
const CudaT* b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
73-
7473
if (b_shape.Size() == 1) {
7574
// if B is (), (1,) or (1, 1), broadcast the scalar
7675
CUBLAS_RETURN_IF_ERROR(cublasCopyHelper(
@@ -91,7 +90,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
9190
b_data, N,
9291
GetConstOnes<CudaT>(M), 1,
9392
/*beta*/ &zero,
94-
out_data, N));
93+
out_data, N, device_prop));
9594
} else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) {
9695
// B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y
9796
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
@@ -103,7 +102,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
103102
GetConstOnes<CudaT>(N), N,
104103
b_data, 1,
105104
/*beta*/ &zero,
106-
out_data, N));
105+
out_data, N, device_prop));
107106
} else {
108107
// B is (M, N), no broadcast needed.
109108
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, M * N * sizeof(float), cudaMemcpyDeviceToDevice));
@@ -126,7 +125,7 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
126125
// ideally we need to set the output buffer contents to 0 if bias is missing,
127126
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
128127
B != nullptr ? &beta : &zero,
129-
out_data, N));
128+
out_data, N, device_prop));
130129

131130
return Status::OK();
132131
}

onnxruntime/core/providers/cuda/math/matmul.cc

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ static bool CanUseStridedBatchedGemm(const TensorShape& left_shape, const Tensor
4949
int64_t left_k = transa ? left_shape[left_num_dims - 2] : left_shape[left_num_dims - 1];
5050

5151
if (right_num_dims >= 3) {
52-
int64_t right_p = right_shape.SizeToDimension(right_num_dims-2);
52+
int64_t right_p = right_shape.SizeToDimension(right_num_dims - 2);
5353
if (left_p != right_p) {
5454
return false;
5555
}
@@ -102,7 +102,7 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
102102
const int ldb = transb ? static_cast<int>(helper.K()) : static_cast<int>(helper.N());
103103
const int ldc = static_cast<int>(helper.N());
104104
int64_t stride_A, stride_B, stride_C, batch_count;
105-
105+
auto& device_prop = GetDeviceProp();
106106
if (helper.OutputOffsets().size() == 1) {
107107
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
108108
Base::CublasHandle(),
@@ -118,10 +118,11 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
118118
lda,
119119
&zero,
120120
reinterpret_cast<CudaT*>(Y->template MutableData<T>()),
121-
ldc));
121+
ldc,
122+
device_prop));
122123
return Status::OK();
123124
} else if (CanUseStridedBatchedGemm(left_X->Shape(), right_X->Shape(),
124-
transa, transb, stride_A, stride_B, stride_C, batch_count)) {
125+
transa, transb, stride_A, stride_B, stride_C, batch_count)) {
125126
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(Base::CublasHandle(),
126127
transB,
127128
transA,

onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,20 +18,29 @@
1818
// Generalize library calls to be use in template functions
1919

2020
// gemm
21-
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* A, int lda, const float* B, int ldb, const float* beta, float* C, int ldc) {
21+
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
22+
int m, int n, int k, const float* alpha, const float* A, int lda,
23+
const float* B, int ldb, const float* beta, float* C, int ldc,
24+
const cudaDeviceProp& /*prop*/) {
2225
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
2326
}
24-
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double* alpha, const double* A, int lda, const double* B, int ldb, const double* beta, double* C, int ldc) {
27+
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
28+
int m, int n, int k, const double* alpha, const double* A, int lda,
29+
const double* B, int ldb, const double* beta, double* C, int ldc,
30+
const cudaDeviceProp& /*prop*/) {
2531
return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
2632
}
27-
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const half* alpha, const half* A, int lda, const half* B, int ldb, const half* beta, half* C, int ldc) {
28-
// Disable below to make sure merged result is on par with before-merge.
33+
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
34+
int m, int n, int k, const half* alpha, const half* A, int lda,
35+
const half* B, int ldb, const half* beta, half* C, int ldc,
36+
const cudaDeviceProp& prop) {
2937
// This does true FP16 computation which is slow for non-Volta GPUs
30-
//if (onnxruntime::cuda::DeviceProp().GetDeviceProps().major >= 7) {
31-
// onnxruntime::cuda::CublasMathModeSetter math_mode_setter( handle, CUBLAS_TENSOR_OP_MATH );
32-
// return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
33-
//}
34-
// This does pseudo FP16 computation (input/output in fp16, computation in fp32)
38+
if (prop.major >= 7) {
39+
onnxruntime::cuda::CublasMathModeSetter math_mode_setter(handle, CUBLAS_TENSOR_OP_MATH);
40+
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
41+
}
42+
43+
//This does pseudo FP16 computation (input/output in fp16, computation in fp32)
3544
float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast<const uint16_t*>(alpha));
3645
float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast<const uint16_t*>(beta));
3746
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
@@ -79,7 +88,7 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
7988
const double* beta,
8089
double* C, int ldc,
8190
long long int strideC,
82-
int batch_count){
91+
int batch_count) {
8392
return cublasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batch_count);
8493
}
8594

0 commit comments

Comments (0)