Skip to content

Commit 0f4bf1c

Browse files
committed
Add GPU device code for testing
1 parent 734cac1 commit 0f4bf1c

File tree

3 files changed

+296
-94
lines changed

3 files changed

+296
-94
lines changed

paddle/math/float16.h

Lines changed: 14 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
// need to define PADDLE_ARM_FP16
16-
1715
#pragma once
1816

1917
#include <cstdint>
2018
#include <istream>
2119
#include <ostream>
2220

2321
#include <cuda.h>
24-
25-
#include "paddle/utils/Logging.h"
26-
27-
#define USE_EIGEN
28-
29-
#ifdef USE_EIGEN // delete this #if macro
3022
#include "unsupported/Eigen/CXX11/Tensor"
31-
#endif
3223

3324
#ifdef __GNUC__
3425
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
@@ -52,27 +43,6 @@ limitations under the License. */
5243
#define PADDLE_HOSTDEVICE
5344
#endif // __CUDACC__
5445

55-
#define STR(x) #x
56-
#define XSTR(x) STR(x)
57-
58-
#ifndef __CUDACC__
59-
#pragma message "__CUDACC__ not defined"
60-
#else
61-
#pragma message "__CUDACC__ defined"
62-
#endif
63-
64-
#ifndef CUDA_VERSION
65-
#pragma message "CUDA_VERSION not defined"
66-
#else
67-
#pragma message "CUDA_VERSION defined: " XSTR(CUDA_VERSION)
68-
#endif
69-
70-
#ifdef __CUDA_ARCH__
71-
#pragma message "The value of CUDA_ARCH: " XSTR(__CUDA_ARCH__)
72-
#else
73-
#pragma message "CUDA ARCH NOT DEFINED!"
74-
#endif
75-
7646
#ifdef __arm__
7747
#define PADDLE_ARM_32
7848
#endif
@@ -113,7 +83,7 @@ namespace paddle {
11383
struct float16;
11484

11585
namespace fp16_impl {
116-
// convert from float to half precision in round-to-nearest-even mode
86+
// Convert from float to half precision in round-to-nearest-even mode
11787
PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f);
11888
PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
11989
} // namespace fp16_impl
@@ -125,7 +95,7 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
12595
struct PADDLE_ALIGN(2) float16 {
12696
uint16_t x;
12797

128-
PADDLE_HOSTDEVICE inline float16() {}
98+
// Default constructor: initializes the raw 16-bit storage x to 0.
PADDLE_HOSTDEVICE inline float16() : x(0) {}
12999

130100
// Copy constructor: copies the raw bits stored in h.x.
PADDLE_HOSTDEVICE inline float16(const float16& h) : x(h.x) {}
131101

@@ -139,21 +109,15 @@ struct PADDLE_ALIGN(2) float16 {
139109
}
140110
#endif // PADDLE_CUDA_FP16
141111

142-
#ifdef USE_EIGEN
143112
// Construct from Eigen::half by copying its raw x field into this object's x.
PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}
144-
#endif // USE_EIGEN
145113

146114
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
147115
(PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
148116
// __fp16 is a native half precision data type for arm cpu,
149117
// float16_t is an alias for __fp16 in arm_fp16.h,
150118
// which is included in arm_neon.h.
151-
// According to gcc, __fp16 can only be used as an argument to fp16
152-
// intrinsic defined in arm_neon.h or as a storage type. It cannot
153-
// be used as a formal function argument.
154-
// TODO(kexinzhao): test it on RPI
155-
PADDLE_HOSTDEVICE inline float16(const float16_t* h) {
156-
x = *reinterpret_cast<uint16_t*>(h);
119+
// Construct from the ARM native half type (float16_t) by copying its bit
// pattern into x.
// The pointer produced from &h is pointer-to-const, so the cast target must be
// `const uint16_t*`: reinterpret_cast is not permitted to cast away constness.
PADDLE_HOSTDEVICE inline float16(const float16_t& h) {
  x = *reinterpret_cast<const uint16_t*>(&h);
}
158122
#endif
159123

@@ -225,17 +189,15 @@ struct PADDLE_ALIGN(2) float16 {
225189
}
226190
#endif
227191

228-
#ifdef USE_EIGEN
229192
// Assignment from Eigen::half: takes over the raw bits held in rhs.x and
// returns a reference to this object.
PADDLE_HOSTDEVICE inline float16& operator=(const Eigen::half& rhs) {
  this->x = rhs.x;
  return *this;
}
233-
#endif // USE_EIGEN
234196

235197
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
236198
(PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
237-
PADDLE_HOSTDEVICE inline float16& operator=(const float16_t* rhs) {
238-
x = *reinterpret_cast<uint16_t*>(rhs);
199+
// Assignment from the ARM native half type (float16_t): copies its bit pattern
// into x and returns a reference to this object.
// &rhs is pointer-to-const, so the cast target must be `const uint16_t*`:
// reinterpret_cast is not permitted to cast away constness.
PADDLE_HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
  x = *reinterpret_cast<const uint16_t*>(&rhs);
  return *this;
}
241203
#endif
@@ -319,17 +281,14 @@ struct PADDLE_ALIGN(2) float16 {
319281
}
320282
#endif // PADDLE_CUDA_FP16
321283

322-
#ifdef USE_EIGEN
323284
// Conversion to Eigen::half: default-constructs an Eigen::half and stores this
// object's raw bits in its x field.
PADDLE_HOSTDEVICE inline operator Eigen::half() const {
  Eigen::half converted;
  converted.x = this->x;
  return converted;
}
328-
#endif // USE_EIGEN
329289

330290
#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
331291
(PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
332-
// check whether it works or not
333292
PADDLE_HOSTDEVICE inline operator float16_t() const {
334293
float16 h = *this;
335294
return *reinterpret_cast<float16_t*>(&h);
@@ -381,10 +340,9 @@ struct PADDLE_ALIGN(2) float16 {
381340
}
382341
};
383342

384-
// arithmetic operators
343+
// Arithmetic operators
385344
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
386345
// Device-side addition: converts both operands to half and adds them with the
// __hadd intrinsic. This branch is only compiled when PADDLE_CUDA_FP16 is
// defined and __CUDA_ARCH__ >= 530 (see the enclosing #if).
// No debug printf here: device printf is serialized and slow, and this
// operator runs once per element.
__device__ inline float16 operator+(const float16& a, const float16& b) {
  return float16(__hadd(half(a), half(b)));
}
390348

@@ -452,7 +410,7 @@ __device__ inline bool operator>=(const float16& a, const float16& b) {
452410
}
453411

454412
// On ARMv8.2-A CPU
455-
#elif defined(PADDLE_NEON_64) && defined(PADDLE_ARM_FP16) && \
413+
#elif defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
456414
(PADDLE_GNUC_VER >= 71 || PADDLE_CLANG_VER >= 39)
457415
__host__ inline float16 operator+(const float16& a, const float16& b) {
458416
return float16(vaddh_f16(float16_t(a), float16_t(b)));
@@ -502,7 +460,7 @@ __host__ inline bool operator!=(const float16& a, const float16& b) {
502460
return !(a == b);
503461
}
504462

505-
// compare only available in NEON_64
463+
#ifdef PADDLE_NEON_64
506464
// Host-side less-than: converts both operands to float16_t and compares them
// with the vclth_f16 intrinsic; the intrinsic's result is converted to bool.
__host__ inline bool operator<(const float16& a, const float16& b) {
  const float16_t lhs = float16_t(a);
  const float16_t rhs = float16_t(b);
  return static_cast<bool>(vclth_f16(lhs, rhs));
}
@@ -518,10 +476,10 @@ __host__ inline bool operator>(const float16& a, const float16& b) {
518476
// Host-side greater-or-equal: converts both operands to float16_t and compares
// them with the vcgeh_f16 intrinsic; the intrinsic's result is converted to
// bool.
__host__ inline bool operator>=(const float16& a, const float16& b) {
  const float16_t lhs = float16_t(a);
  const float16_t rhs = float16_t(b);
  return static_cast<bool>(vcgeh_f16(lhs, rhs));
}
479+
#endif // PADDLE_NEON_64
521480

522-
#else // software emulation on other cpu
481+
#else // Software emulation on other cpu
523482
// Software-emulated addition: converts both operands to float, adds them in
// single precision, and converts the sum back to float16.
// No logging here: this function is PADDLE_HOSTDEVICE, so a host-only
// LOG(INFO) call does not belong in it, and it would run on every addition.
PADDLE_HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
  return float16(float(a) + float(b));
}
527485

@@ -624,7 +582,7 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
624582
half tmp = __float2half(f);
625583
return *reinterpret_cast<float16*>(&tmp);
626584

627-
#elif defined(PADDLE_NEON_64) // test on RPI
585+
#elif defined(PADDLE_NEON_64)
628586
float16 res;
629587
asm volatile(
630588
"ld1 {v0.s}[0], [%[float_ptr]]\n"
@@ -638,7 +596,7 @@ PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
638596
"memory", "v0");
639597
return res;
640598

641-
#elif defined(PADDLE_NEON_32) // test on RPI
599+
#elif defined(PADDLE_NEON_32)
642600
float16 res;
643601
asm volatile(
644602
"vld1.32 {d0[0]}, [%[float_ptr]]\n"
@@ -689,7 +647,7 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
689647
float res;
690648
asm volatile(
691649
"ld1 {v0.h}[0], [%[half_ptr]]\n"
692-
"FCVT s0, h0\n"
650+
"fcvt s0, h0\n"
693651
"st1 {v0.s}[0], [%[float_ptr]]\n"
694652
: // outputs
695653
: // inputs
@@ -739,5 +697,4 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h) {
739697
}
740698

741699
} // namespace fp16_impl
742-
743700
} // namespace paddle

paddle/math/tests/test_float16.cpp

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

12-
#include <gtest/gtest.h>
1312
#include "paddle/math/float16.h"
1413

14+
#include <gtest/gtest.h>
15+
1516
namespace paddle {
1617

1718
TEST(float16, conversion_cpu) {
18-
LOG(INFO) << "cpu test started!";
19-
20-
// Conversion to and from Eigen::half
21-
EXPECT_EQ(float16(Eigen::half(float16(1.0f))).x, 0x3c00);
22-
EXPECT_EQ(float16(Eigen::half(float16(0.5f))).x, 0x3800);
23-
EXPECT_EQ(float16(Eigen::half(float16(0.33333f))).x, 0x3555);
24-
EXPECT_EQ(float16(Eigen::half(float16(0.0f))).x, 0x0000);
25-
EXPECT_EQ(float16(Eigen::half(float16(-0.0f))).x, 0x8000);
26-
EXPECT_EQ(float16(Eigen::half(float16(65504.0f))).x, 0x7bff);
27-
EXPECT_EQ(float16(Eigen::half(float16(65536.0f))).x, 0x7c00);
19+
// Explicit conversion from Eigen::half
20+
EXPECT_EQ(float16(Eigen::half(1.0f)).x, 0x3c00);
21+
EXPECT_EQ(float16(Eigen::half(0.5f)).x, 0x3800);
22+
EXPECT_EQ(float16(Eigen::half(0.33333f)).x, 0x3555);
23+
EXPECT_EQ(float16(Eigen::half(0.0f)).x, 0x0000);
24+
EXPECT_EQ(float16(Eigen::half(-0.0f)).x, 0x8000);
25+
EXPECT_EQ(float16(Eigen::half(65504.0f)).x, 0x7bff);
26+
EXPECT_EQ(float16(Eigen::half(65536.0f)).x, 0x7c00);
2827

2928
// Conversion from float
3029
EXPECT_EQ(float16(1.0f).x, 0x3c00);
@@ -36,14 +35,91 @@ TEST(float16, conversion_cpu) {
3635
EXPECT_EQ(float16(65536.0f).x, 0x7c00);
3736

3837
// Conversion from double
38+
EXPECT_EQ(float16(1.0).x, 0x3c00);
39+
EXPECT_EQ(float16(0.5).x, 0x3800);
40+
EXPECT_EQ(float16(0.33333).x, 0x3555);
41+
EXPECT_EQ(float16(0.0).x, 0x0000);
42+
EXPECT_EQ(float16(-0.0).x, 0x8000);
43+
EXPECT_EQ(float16(65504.0).x, 0x7bff);
44+
EXPECT_EQ(float16(65536.0).x, 0x7c00);
3945

4046
// Conversion from int
47+
EXPECT_EQ(float16(-1).x, 0xbc00);
48+
EXPECT_EQ(float16(0).x, 0x0000);
49+
EXPECT_EQ(float16(1).x, 0x3c00);
50+
EXPECT_EQ(float16(2).x, 0x4000);
51+
EXPECT_EQ(float16(3).x, 0x4200);
4152

4253
// Conversion from bool
54+
EXPECT_EQ(float16(true).x, 0x3c00);
55+
EXPECT_EQ(float16(false).x, 0x0000);
56+
57+
// Implicit conversion to and from Eigen::half
58+
Eigen::half tmp = float16(1.0f);
59+
float16 v_conv = tmp;
60+
EXPECT_EQ(tmp.x, 0x3c00);
61+
EXPECT_EQ(v_conv.x, 0x3c00);
62+
63+
// Default constructor
64+
float16 v_def;
65+
EXPECT_EQ(v_def.x, 0x0000);
66+
67+
// Assignment operator
68+
float16 v_assign;
69+
v_assign = v_def;
70+
EXPECT_EQ(v_assign.x, 0x0000);
71+
v_assign = Eigen::half(1.0f);
72+
EXPECT_EQ(v_assign.x, 0x3c00);
73+
v_assign = 0.5f;
74+
EXPECT_EQ(v_assign.x, 0x3800);
75+
v_assign = 0.33333;
76+
EXPECT_EQ(v_assign.x, 0x3555);
77+
v_assign = -1;
78+
EXPECT_EQ(v_assign.x, 0xbc00);
79+
v_assign = true;
80+
EXPECT_EQ(v_assign.x, 0x3c00);
81+
82+
// Conversion operator
83+
EXPECT_EQ(Eigen::half(float16(1.0f)).x, 0x3c00);
84+
EXPECT_EQ(float(float16(0.5f)), 0.5f);
85+
EXPECT_NEAR(double(float16(0.33333)), 0.33333, 0.0001);
86+
EXPECT_EQ(int(float16(-1)), -1);
87+
EXPECT_EQ(bool(float16(true)), true);
4388
}
4489

45-
TEST(float16, arithmetic_cpu) { EXPECT_EQ(float(float16(2) + float16(2)), 4); }
90+
// Checks +, -, *, / and unary minus on float16 by converting each result to
// float. Results that are exact in half precision are compared with EXPECT_EQ;
// the others are compared with EXPECT_NEAR and an explicit tolerance.
TEST(float16, arithmetic_cpu) {
  // Addition
  const float sum_exact = float(float16(1) + float16(1));
  const float sum_cancel = float(float16(5) + float16(-5));
  const float sum_fraction = float(float16(0.33333f) + float16(0.66667f));
  EXPECT_EQ(sum_exact, 2);
  EXPECT_EQ(sum_cancel, 0);
  EXPECT_NEAR(sum_fraction, 1.0f, 0.001);

  // Subtraction
  const float diff_exact = float(float16(3) - float16(5));
  const float diff_fraction = float(float16(0.66667f) - float16(0.33333f));
  EXPECT_EQ(diff_exact, -2);
  EXPECT_NEAR(diff_fraction, 0.33334f, 0.001);

  // Multiplication
  const float prod_positive = float(float16(3.3f) * float16(2.0f));
  const float prod_negatives = float(float16(-2.1f) * float16(-3.0f));
  EXPECT_NEAR(prod_positive, 6.6f, 0.01);
  EXPECT_NEAR(prod_negatives, 6.3f, 0.01);

  // Division
  const float quot_fraction = float(float16(2.0f) / float16(3.0f));
  const float quot_exact = float(float16(1.0f) / float16(2.0f));
  EXPECT_NEAR(quot_fraction, 0.66667f, 0.001);
  EXPECT_EQ(quot_exact, 0.5f);

  // Unary minus
  const float negated_positive = float(-float16(512.0f));
  const float negated_negative = float(-float16(-512.0f));
  EXPECT_EQ(negated_positive, -512.0f);
  EXPECT_EQ(negated_negative, 512.0f);
}
46103

47-
TEST(float16, comparison_cpu) { EXPECT_TRUE(float16(1.0f) > float16(0.5f)); }
104+
// Checks the six comparison operators on float16, including the case where
// positive and negative zero must compare equal and unordered by < and >.
TEST(float16, comparison_cpu) {
  const float16 one(1.0f);
  const float16 one_again(1.0f);
  const float16 two(2.0f);
  const float16 two_again(2.0f);
  const float16 point_five(0.5f);
  const float16 minus_one(-1.0f);
  const float16 minus_one_again(-1.0f);
  const float16 minus_point_five(-0.5f);
  const float16 minus_two(-2.0f);
  const float16 minus_two_again(-2.0f);

  // Equality and inequality
  EXPECT_TRUE(one == one_again);
  EXPECT_FALSE(minus_one == minus_point_five);
  EXPECT_TRUE(one != point_five);
  EXPECT_FALSE(minus_one != minus_one_again);

  // Ordering
  EXPECT_TRUE(one < two);
  EXPECT_FALSE(minus_one < minus_one_again);
  EXPECT_TRUE(one <= one_again);
  EXPECT_TRUE(two > one);
  EXPECT_FALSE(minus_two > minus_two_again);
  EXPECT_TRUE(two >= two_again);

  // Signed zeros
  const float16 pos_zero(0.0f);
  const float16 neg_zero(-0.0f);
  EXPECT_TRUE(pos_zero == neg_zero);
  EXPECT_TRUE(pos_zero <= neg_zero);
  EXPECT_TRUE(pos_zero >= neg_zero);
  EXPECT_FALSE(pos_zero < neg_zero);
  EXPECT_FALSE(neg_zero < pos_zero);
  EXPECT_FALSE(pos_zero > neg_zero);
  EXPECT_FALSE(neg_zero > pos_zero);
}
48124

49125
} // namespace paddle

0 commit comments

Comments
 (0)