Skip to content

Commit 1998d5a

Browse files
committed
add gpu info func to get compute cap
1 parent d400b41 commit 1998d5a

File tree

4 files changed

+31
-13
lines changed

4 files changed

+31
-13
lines changed

paddle/fluid/operators/math/math_function_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,20 @@
1414
#include "paddle/fluid/operators/math/math_function.h"
1515
#include "gtest/gtest.h"
1616

17+
#include <iostream>
18+
1719
TEST(math_function, gemm_notrans_cblas) {
1820
paddle::framework::Tensor input1;
1921
paddle::framework::Tensor input2;
2022
paddle::framework::Tensor input3;
2123

24+
// fp16 GEMM in cublas requires GPU compute capability >= 53
25+
if (GetCUDAComputeCapability(0) < 53) {  // skip only when the device is below the fp16 GEMM requirement
26+
std::cout << "Compute capability is " << GetCUDAComputeCapability(0)
27+
<< std::endl;
28+
return;
29+
}
30+
2231
int m = 2;
2332
int n = 3;
2433
int k = 3;

paddle/fluid/operators/math/math_function_test.cu

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,6 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
2424
}
2525
}
2626

27-
bool is_fp16_supported(int device_id) {
28-
cudaDeviceProp device_prop;
29-
cudaDeviceProperties(&device_prop, device_id);
30-
PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess);
31-
int compute_capability = device_prop.major * 10 + device_prop.minor;
32-
std::cout << "compute_capability is " << compute_capability << std::endl;
33-
return compute_capability >= 53;
34-
}
35-
3627
TEST(math_function, notrans_mul_trans_fp32) {
3728
using namespace paddle::framework;
3829
using namespace paddle::platform;
@@ -73,7 +64,10 @@ TEST(math_function, notrans_mul_trans_fp16) {
7364
using namespace paddle::framework;
7465
using namespace paddle::platform;
7566

76-
if (!is_fp16_supported(0)) {
67+
// fp16 GEMM in cublas requires GPU compute capability >= 53
68+
if (GetCUDAComputeCapability(0) < 53) {  // skip only when the device is below the fp16 GEMM requirement
69+
std::cout << "Compute capability is " << GetCUDAComputeCapability(0)
70+
<< std::endl;
7771
return;
7872
}
7973

@@ -154,7 +148,8 @@ TEST(math_function, trans_mul_notrans_fp16) {
154148
using namespace paddle::framework;
155149
using namespace paddle::platform;
156150

157-
if (!is_fp16_supported(0)) {
151+
// fp16 GEMM in cublas requires GPU compute capability >= 53
152+
if (GetCUDAComputeCapability(0) < 53) {  // skip only when the device is below the fp16 GEMM requirement
158153
return;
159154
}
160155

@@ -256,7 +251,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
256251
using namespace paddle::framework;
257252
using namespace paddle::platform;
258253

259-
if (!is_fp16_supported(0)) {
254+
// fp16 GEMM in cublas requires GPU compute capability >= 53
255+
if (GetCUDAComputeCapability(0) < 53) {  // skip only when the device is below the fp16 GEMM requirement
260256
return;
261257
}
262258

@@ -367,7 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
367363
using namespace paddle::framework;
368364
using namespace paddle::platform;
369365

370-
if (!is_fp16_supported(0)) {
366+
// fp16 GEMM in cublas requires GPU compute capability >= 53
367+
if (GetCUDAComputeCapability(0) < 53) {  // skip only when the device is below the fp16 GEMM requirement
371368
return;
372369
}
373370

paddle/fluid/platform/gpu_info.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ int GetCUDADeviceCount() {
3333
return count;
3434
}
3535

36+
// Returns the compute capability of GPU `id`, encoded as major * 10 + minor
// (so a device reporting major 5, minor 3 yields 53).
// Enforces that `id` is less than GetCUDADeviceCount() and that the
// cudaGetDeviceProperties call passes PADDLE_ENFORCE.
int GetCUDAComputeCapability(int id) {
37+
// Only the upper bound of `id` is checked here.
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
38+
cudaDeviceProp device_prop;
39+
// Fill device_prop for the requested device; the result code is handed to
// PADDLE_ENFORCE together with a message naming this function.
PADDLE_ENFORCE(cudaGetDeviceProperties(&device_prop, id),
40+
"cudaGetDeviceProperties failed in "
41+
"paddle::platform::GetCUDAComputeCapability");
42+
// Fold the major/minor version pair into a single comparable integer.
return device_prop.major * 10 + device_prop.minor;
43+
}
44+
3645
int GetCUDAMultiProcessors(int id) {
3746
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
3847
int count;

paddle/fluid/platform/gpu_info.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ const std::string kEnvFractionGpuMemoryToUse =
3030
//! Get the total number of GPU devices in system.
3131
int GetCUDADeviceCount();
3232

33+
//! Get the compute capability of the ith GPU (format: major * 10 + minor)
34+
int GetCUDAComputeCapability(int i);
35+
3336
//! Get the MultiProcessors of the ith GPU.
3437
int GetCUDAMultiProcessors(int i);
3538

0 commit comments

Comments
 (0)