Skip to content

Commit d400b41

Browse files
committed
fix math function arch mismatch for older GPU
1 parent ccc5418 commit d400b41

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

paddle/fluid/operators/math/math_function_test.cu

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "gtest/gtest.h"
1515
#include "paddle/fluid/operators/math/math_function.h"
1616

17+
#include <iostream>
18+
1719
void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
1820
const std::vector<float>& data) {
1921
PADDLE_ENFORCE_EQ(size, data.size());
@@ -22,6 +24,15 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
2224
}
2325
}
2426

27+
bool is_fp16_supported(int device_id) {
28+
cudaDeviceProp device_prop;
29+
cudaDeviceProperties(&device_prop, device_id);
30+
PADDLE_ENFORCE_EQ(cudaGetLastError(), cudaSuccess);
31+
int compute_capability = device_prop.major * 10 + device_prop.minor;
32+
std::cout << "compute_capability is " << compute_capability << std::endl;
33+
return compute_capability >= 53;
34+
}
35+
2536
TEST(math_function, notrans_mul_trans_fp32) {
2637
using namespace paddle::framework;
2738
using namespace paddle::platform;
@@ -62,6 +73,10 @@ TEST(math_function, notrans_mul_trans_fp16) {
6273
using namespace paddle::framework;
6374
using namespace paddle::platform;
6475

76+
if (!is_fp16_supported(0)) {
77+
return;
78+
}
79+
6580
Tensor input1;
6681
Tensor input1_gpu;
6782
Tensor input2_gpu;
@@ -139,6 +154,10 @@ TEST(math_function, trans_mul_notrans_fp16) {
139154
using namespace paddle::framework;
140155
using namespace paddle::platform;
141156

157+
if (!is_fp16_supported(0)) {
158+
return;
159+
}
160+
142161
Tensor input1;
143162
Tensor input1_gpu;
144163
Tensor input2_gpu;
@@ -237,6 +256,10 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
237256
using namespace paddle::framework;
238257
using namespace paddle::platform;
239258

259+
if (!is_fp16_supported(0)) {
260+
return;
261+
}
262+
240263
Tensor input1;
241264
Tensor input2;
242265
Tensor input3;
@@ -344,6 +367,10 @@ TEST(math_function, gemm_trans_cublas_fp16) {
344367
using namespace paddle::framework;
345368
using namespace paddle::platform;
346369

370+
if (!is_fp16_supported(0)) {
371+
return;
372+
}
373+
347374
Tensor input1;
348375
Tensor input2;
349376
Tensor input3;

0 commit comments

Comments
 (0)