|
18 | 18 | // Generalize library calls to be used in template functions |
19 | 19 |
|
20 | 20 | // gemm |
// Type-dispatched GEMM wrapper (float specialization): forwards directly to
// cublasSgemm. The trailing cudaDeviceProp argument is deliberately unused —
// it only exists so every overload shares one signature and templated callers
// can invoke cublasGemmHelper uniformly for all element types.
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                       int m, int n, int k, const float* alpha, const float* A, int lda,
                                       const float* B, int ldb, const float* beta, float* C, int ldc,
                                       const cudaDeviceProp& /*prop*/) {
  // Single-precision GEMM; no device-capability-based dispatch is required.
  return cublasSgemm(handle, transa, transb,
                     m, n, k,
                     alpha, A, lda,
                     B, ldb,
                     beta, C, ldc);
}
// Type-dispatched GEMM wrapper (double specialization): forwards directly to
// cublasDgemm. As with the other overloads, the cudaDeviceProp parameter is
// ignored here; it keeps the signature identical across element types so this
// helper can be used from template functions.
inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                                       int m, int n, int k, const double* alpha, const double* A, int lda,
                                       const double* B, int ldb, const double* beta, double* C, int ldc,
                                       const cudaDeviceProp& /*prop*/) {
  // Double-precision GEMM; no device-capability-based dispatch is required.
  return cublasDgemm(handle, transa, transb,
                     m, n, k,
                     alpha, A, lda,
                     B, ldb,
                     beta, C, ldc);
}
27 | | -inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const half* alpha, const half* A, int lda, const half* B, int ldb, const half* beta, half* C, int ldc) { |
28 | | - // Disable below to make sure merged result is on par with before-merge. |
| 33 | +inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, |
| 34 | + int m, int n, int k, const half* alpha, const half* A, int lda, |
| 35 | + const half* B, int ldb, const half* beta, half* C, int ldc, |
| 36 | + const cudaDeviceProp& prop) { |
29 | 37 | // This does true FP16 computation which is slow for non-Volta GPUs |
30 | | - //if (onnxruntime::cuda::DeviceProp().GetDeviceProps().major >= 7) { |
31 | | - // onnxruntime::cuda::CublasMathModeSetter math_mode_setter( handle, CUBLAS_TENSOR_OP_MATH ); |
32 | | - // return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); |
33 | | - //} |
34 | | - // This does pseudo FP16 computation (input/output in fp16, computation in fp32) |
| 38 | + if (prop.major >= 7) { |
| 39 | + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(handle, CUBLAS_TENSOR_OP_MATH); |
| 40 | + return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); |
| 41 | + } |
| 42 | + |
| 43 | + //This does pseudo FP16 computation (input/output in fp16, computation in fp32) |
35 | 44 | float h_a = onnxruntime::math::halfToFloat(*reinterpret_cast<const uint16_t*>(alpha)); |
36 | 45 | float h_b = onnxruntime::math::halfToFloat(*reinterpret_cast<const uint16_t*>(beta)); |
37 | 46 | cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); |
@@ -79,7 +88,7 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, |
79 | 88 | const double* beta, |
80 | 89 | double* C, int ldc, |
81 | 90 | long long int strideC, |
82 | | - int batch_count){ |
| 91 | + int batch_count) { |
83 | 92 | return cublasDgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batch_count); |
84 | 93 | } |
85 | 94 |
|
|
0 commit comments