14
14
#include " gtest/gtest.h"
15
15
#include " paddle/fluid/operators/math/math_function.h"
16
16
17
+ #include < iostream>
18
+
17
19
void fill_fp16_data (paddle::platform::float16* in_ptr, size_t size,
18
20
const std::vector<float >& data) {
19
21
PADDLE_ENFORCE_EQ (size, data.size ());
@@ -22,6 +24,15 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
22
24
}
23
25
}
24
26
27
+ bool is_fp16_supported (int device_id) {
28
+ cudaDeviceProp device_prop;
29
+ cudaDeviceProperties (&device_prop, device_id);
30
+ PADDLE_ENFORCE_EQ (cudaGetLastError (), cudaSuccess);
31
+ int compute_capability = device_prop.major * 10 + device_prop.minor ;
32
+ std::cout << " compute_capability is " << compute_capability << std::endl;
33
+ return compute_capability >= 53 ;
34
+ }
35
+
25
36
TEST (math_function, notrans_mul_trans_fp32) {
26
37
using namespace paddle ::framework;
27
38
using namespace paddle ::platform;
@@ -62,6 +73,10 @@ TEST(math_function, notrans_mul_trans_fp16) {
62
73
using namespace paddle ::framework;
63
74
using namespace paddle ::platform;
64
75
76
+ if (!is_fp16_supported (0 )) {
77
+ return ;
78
+ }
79
+
65
80
Tensor input1;
66
81
Tensor input1_gpu;
67
82
Tensor input2_gpu;
@@ -139,6 +154,10 @@ TEST(math_function, trans_mul_notrans_fp16) {
139
154
using namespace paddle ::framework;
140
155
using namespace paddle ::platform;
141
156
157
+ if (!is_fp16_supported (0 )) {
158
+ return ;
159
+ }
160
+
142
161
Tensor input1;
143
162
Tensor input1_gpu;
144
163
Tensor input2_gpu;
@@ -237,6 +256,10 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
237
256
using namespace paddle ::framework;
238
257
using namespace paddle ::platform;
239
258
259
+ if (!is_fp16_supported (0 )) {
260
+ return ;
261
+ }
262
+
240
263
Tensor input1;
241
264
Tensor input2;
242
265
Tensor input3;
@@ -344,6 +367,10 @@ TEST(math_function, gemm_trans_cublas_fp16) {
344
367
using namespace paddle ::framework;
345
368
using namespace paddle ::platform;
346
369
370
+ if (!is_fp16_supported (0 )) {
371
+ return ;
372
+ }
373
+
347
374
Tensor input1;
348
375
Tensor input2;
349
376
Tensor input3;
0 commit comments