Skip to content

Commit 3b44b84

Browse files
committed
address comments
1 parent 95de761 commit 3b44b84

File tree

4 files changed

+37
-20
lines changed

4 files changed

+37
-20
lines changed

paddle/fluid/operators/math/math_function.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ void gemm<platform::CUDADeviceContext, float16>(
4545
const half* h_B = reinterpret_cast<const half*>(B);
4646
half* h_C = reinterpret_cast<half*>(C);
4747

48+
// TODO(kexinzhao): add processing code for compute capability < 53 case
49+
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
50+
"cublas Hgemm requires GPU compute capability >= 53");
4851
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
4952
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
5053
h_A, lda, &h_beta, h_C, N));
@@ -106,6 +109,9 @@ void gemm<platform::CUDADeviceContext, float16>(
106109
const half* h_B = reinterpret_cast<const half*>(B);
107110
half* h_C = reinterpret_cast<half*>(C);
108111

112+
// TODO(kexinzhao): add processing code for compute capability < 53 case
113+
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
114+
"cublas Hgemm requires GPU compute capability >= 53");
109115
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
110116
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
111117
h_A, lda, &h_beta, h_C, ldc));
@@ -251,6 +257,9 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
251257
const half* h_B = reinterpret_cast<const half*>(B);
252258
half* h_C = reinterpret_cast<half*>(C);
253259

260+
// TODO(kexinzhao): add processing code for compute capability < 53 case
261+
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
262+
"cublas Hgemm requires GPU compute capability >= 53");
254263
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
255264
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
256265
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));

paddle/fluid/operators/math/math_function_test.cu

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,6 @@ TEST(math_function, notrans_mul_trans_fp16) {
6262
using namespace paddle::framework;
6363
using namespace paddle::platform;
6464

65-
// fp16 GEMM in cublas requires GPU compute capability >= 53
66-
if (GetCUDAComputeCapability(0) < 53) {
67-
return;
68-
}
69-
7065
Tensor input1;
7166
Tensor input1_gpu;
7267
Tensor input2_gpu;
@@ -77,6 +72,11 @@ TEST(math_function, notrans_mul_trans_fp16) {
7772
CUDAPlace gpu_place(0);
7873
CUDADeviceContext context(gpu_place);
7974

75+
// fp16 GEMM in cublas requires GPU compute capability >= 53
76+
if (context.GetComputeCapability() < 53) {
77+
return;
78+
}
79+
8080
float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place);
8181
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
8282

@@ -144,11 +144,6 @@ TEST(math_function, trans_mul_notrans_fp16) {
144144
using namespace paddle::framework;
145145
using namespace paddle::platform;
146146

147-
// fp16 GEMM in cublas requires GPU compute capability >= 53
148-
if (GetCUDAComputeCapability(0) < 53) {
149-
return;
150-
}
151-
152147
Tensor input1;
153148
Tensor input1_gpu;
154149
Tensor input2_gpu;
@@ -159,6 +154,11 @@ TEST(math_function, trans_mul_notrans_fp16) {
159154
CUDAPlace gpu_place(0);
160155
CUDADeviceContext context(gpu_place);
161156

157+
// fp16 GEMM in cublas requires GPU compute capability >= 53
158+
if (context.GetComputeCapability() < 53) {
159+
return;
160+
}
161+
162162
float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place);
163163
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
164164

@@ -247,11 +247,6 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
247247
using namespace paddle::framework;
248248
using namespace paddle::platform;
249249

250-
// fp16 GEMM in cublas requires GPU compute capability >= 53
251-
if (GetCUDAComputeCapability(0) < 53) {
252-
return;
253-
}
254-
255250
Tensor input1;
256251
Tensor input2;
257252
Tensor input3;
@@ -263,6 +258,11 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
263258
CUDAPlace gpu_place(0);
264259
CUDADeviceContext context(gpu_place);
265260

261+
// fp16 GEMM in cublas requires GPU compute capability >= 53
262+
if (context.GetComputeCapability() < 53) {
263+
return;
264+
}
265+
266266
int m = 2;
267267
int n = 3;
268268
int k = 3;
@@ -359,11 +359,6 @@ TEST(math_function, gemm_trans_cublas_fp16) {
359359
using namespace paddle::framework;
360360
using namespace paddle::platform;
361361

362-
// fp16 GEMM in cublas requires GPU compute capability >= 53
363-
if (GetCUDAComputeCapability(0) < 53) {
364-
return;
365-
}
366-
367362
Tensor input1;
368363
Tensor input2;
369364
Tensor input3;
@@ -375,6 +370,11 @@ TEST(math_function, gemm_trans_cublas_fp16) {
375370
CUDAPlace gpu_place(0);
376371
CUDADeviceContext context(gpu_place);
377372

373+
// fp16 GEMM in cublas requires GPU compute capability >= 53
374+
if (context.GetComputeCapability() < 53) {
375+
return;
376+
}
377+
378378
int m = 2;
379379
int n = 3;
380380
int k = 3;

paddle/fluid/platform/device_context.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
127127

128128
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
129129
SetDeviceId(place_.device);
130+
compute_capability = GetCUDAComputeCapability(place_.device);
130131
multi_process = GetCUDAMultiProcessors(place_.device);
131132
max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
132133
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
@@ -162,6 +163,10 @@ void CUDADeviceContext::Wait() const {
162163
PADDLE_ENFORCE(cudaGetLastError());
163164
}
164165

166+
int CUDADeviceContext::GetComputeCapability() const {
167+
return compute_capability;
168+
}
169+
165170
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
166171
return multi_process * max_threads_per_mp;
167172
}

paddle/fluid/platform/device_context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class CUDADeviceContext : public DeviceContext {
7979
/*! \brief Return place in the device context. */
8080
Place GetPlace() const override;
8181

82+
int GetComputeCapability() const;
83+
8284
/*! \brief Return the max physical thread count in the device context */
8385
int GetMaxPhysicalThreadCount() const;
8486

@@ -104,6 +106,7 @@ class CUDADeviceContext : public DeviceContext {
104106
cudnnHandle_t cudnn_handle_;
105107
cublasHandle_t cublas_handle_;
106108

109+
int compute_capability;
107110
int multi_process;
108111
int max_threads_per_mp;
109112
};

0 commit comments

Comments
 (0)