Skip to content

Commit 734cac1

Browse files
committed
fix CUDA_VERSION issue
1 parent 080ff0c commit 734cac1

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

paddle/math/float16.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ limitations under the License. */
2020
#include <istream>
2121
#include <ostream>
2222

23+
#include <cuda.h>
24+
25+
#include "paddle/utils/Logging.h"
26+
2327
#define USE_EIGEN
2428

2529
#ifdef USE_EIGEN // delete this #if macro
@@ -48,6 +52,27 @@ limitations under the License. */
4852
#define PADDLE_HOSTDEVICE
4953
#endif // __CUDACC__
5054

55+
#define STR(x) #x
56+
#define XSTR(x) STR(x)
57+
58+
#ifndef __CUDACC__
59+
#pragma message "__CUDACC__ not defined"
60+
#else
61+
#pragma message "__CUDACC__ defined"
62+
#endif
63+
64+
#ifndef CUDA_VERSION
65+
#pragma message "CUDA_VERSION not defined"
66+
#else
67+
#pragma message "CUDA_VERSION defined: " XSTR(CUDA_VERSION)
68+
#endif
69+
70+
#ifdef __CUDA_ARCH__
71+
#pragma message "The value of CUDA_ARCH: " XSTR(__CUDA_ARCH__)
72+
#else
73+
#pragma message "CUDA ARCH NOT DEFINED!"
74+
#endif
75+
5176
#ifdef __arm__
5277
#define PADDLE_ARM_32
5378
#endif
@@ -359,6 +384,7 @@ struct PADDLE_ALIGN(2) float16 {
359384
// arithmetic operators
360385
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
361386
__device__ inline float16 operator+(const float16& a, const float16& b) {
387+
printf("GPU Intrinsic used!");
362388
return float16(__hadd(half(a), half(b)));
363389
}
364390

@@ -495,6 +521,7 @@ __host__ inline bool operator>=(const float16& a, const float16& b) {
495521

496522
#else // software emulation on other cpu
497523
PADDLE_HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
524+
LOG(INFO) << "CPU emulation used";
498525
return float16(float(a) + float(b));
499526
}
500527

@@ -656,7 +683,7 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
656683
PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
657684
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
658685
half tmp = *reinterpret_cast<half*>(&h);
659-
return __half2float(h);
686+
return __half2float(tmp);
660687

661688
#elif defined(PADDLE_NEON_64)
662689
float res;

paddle/math/tests/test_float16.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License. */
1515
namespace paddle {
1616

1717
TEST(float16, conversion_cpu) {
18+
LOG(INFO) << "cpu test started!";
19+
1820
// Conversion to and from Eigen::half
1921
EXPECT_EQ(float16(Eigen::half(float16(1.0f))).x, 0x3c00);
2022
EXPECT_EQ(float16(Eigen::half(float16(0.5f))).x, 0x3800);

paddle/math/tests/test_float16.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace paddle {
1616

1717
#ifdef PADDLE_CUDA_FP16
1818
TEST(float16, conversion_gpu) {
19+
LOG(INFO) << "GPU tests started";
20+
1921
// Conversion to and from cuda half
2022
float16 v1 = half(float16(1.0f));
2123
EXPECT_EQ(v1.x, 0x3c00);

0 commit comments

Comments
 (0)