
Commit af37838

add test for float16
1 parent d9642cb commit af37838

File tree: 2 files changed, +10 -9 lines changed

paddle/math/float16.h

Lines changed: 8 additions & 8 deletions
@@ -20,7 +20,7 @@ limitations under the License. */
 #include <istream>
 #include <ostream>
 
-#include <cuda.h>
+#define USE_EIGEN
 
 #ifdef USE_EIGEN  // delete this #if macro
 #include "Eigen/src/Core/arch/CUDA/Half.h"
@@ -100,8 +100,6 @@ PADDLE_HOSTDEVICE inline float half_to_float(float16 h);
 struct PADDLE_ALIGN(2) float16 {
   uint16_t x;
 
-  // explicit for different types, implicit for half and Eigen::half
-
   PADDLE_HOSTDEVICE inline float16() {}
 
   PADDLE_HOSTDEVICE inline float16(const float16& h) : x(h.x) {}
@@ -120,7 +118,8 @@ struct PADDLE_ALIGN(2) float16 {
   PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}
 #endif  // USE_EIGEN
 
-#ifdef PADDLE_NEON
+#if (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34) && \
+    defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16)
   // __fp16 is a native half precision data type for arm cpu,
   // float16_t is an alias for __fp16 in arm_fp16.h,
   // which is included in arm_neon.h.
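
The guard no longer checks only PADDLE_NEON: the native __fp16 path is now enabled only when the compiler is recent enough (GCC >= 6.1 or Clang >= 3.4) and ARM FP16 support is detected. The PADDLE_GNUC_VER and PADDLE_CLANG_VER macros are defined elsewhere in the file and are not part of this diff; the sketch below shows one common way such macros are derived from compiler builtins, as an assumption rather than the file's actual definition.

// Hedged sketch: two-digit compiler versions (61 == GCC 6.1, 34 == Clang 3.4)
// derived from predefined macros. The real definitions in float16.h may differ.
#if defined(__GNUC__) && !defined(__clang__)
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else
#define PADDLE_GNUC_VER 0
#endif

#if defined(__clang__)
#define PADDLE_CLANG_VER (__clang_major__ * 10 + __clang_minor__)
#else
#define PADDLE_CLANG_VER 0
#endif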
@@ -208,7 +207,8 @@ struct PADDLE_ALIGN(2) float16 {
   }
 #endif  // USE_EIGEN
 
-#ifdef PADDLE_NEON
+#if (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34) && \
+    defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16)
   PADDLE_HOSTDEVICE inline float16& operator=(const float16_t* rhs) {
     x = *reinterpret_cast<uint16_t*>(rhs);
     return *this;
@@ -302,7 +302,8 @@ struct PADDLE_ALIGN(2) float16 {
   }
 #endif  // USE_EIGEN
 
-#ifdef PADDLE_NEON
+#if (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34) && \
+    defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16)
   // check whether it works or not
   PADDLE_HOSTDEVICE inline operator float16_t() const {
     float16 h = *this;
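
The operator float16_t() conversion above is still marked "check whether it works or not". Below is a hypothetical usage sketch that would exercise it under the NEON/ARM FP16 build; the function name and the paddle namespace are illustrative assumptions, not code from this commit.

#include <arm_neon.h>

#include "paddle/math/float16.h"

// Copy float16 values into a native __fp16 (float16_t) buffer; the cast on
// each element goes through the operator float16_t() guarded above.
void to_native_fp16(const paddle::float16* in, float16_t* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<float16_t>(in[i]);
  }
}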
@@ -371,7 +372,6 @@ __device__ inline float16 operator*(const float16& a, const float16& b) {
 
 __device__ inline float16 operator/(const float16& a, const float16& b) {
   // TODO(kexinzhao): check the cuda version that starts to support __hdiv
-  // instinsic
   float num = __half2float(half(a));
   float denom = __half2float(half(b));
   return float16(num / denom);
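
The TODO asks from which CUDA version the __hdiv intrinsic is usable. A hedged sketch of what the operator could become once that is confirmed, assuming the usual compute-capability 5.3 gate for native half arithmetic and the float16 <-> half conversions used elsewhere in this file; this is not the commit's code.

__device__ inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  // Native half-precision division where the hardware supports it.
  return float16(__hdiv(half(a), half(b)));
#else
  // Fall back to the float path shown in the diff above.
  float num = __half2float(half(a));
  float denom = __half2float(half(b));
  return float16(num / denom);
#endif
}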
@@ -595,7 +595,7 @@ constexpr int32_t minD = minC - subC - 1;
 PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f) {
 #if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
   half tmp = __float2half(f);
-  return *reinterpret_cast<float16*>(&(tmp));
+  return *reinterpret_cast<float16*>(&tmp);
 
 #elif defined(PADDLE_NEON_64)  // test on RPI
   float16 res;
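
The changed line only drops the redundant parentheses around tmp; the pointer reinterpret_cast still relies on half and float16 being 2-byte types with the same bit layout. A hedged alternative that avoids the aliasing cast entirely, shown as a sketch of the same bit-copy idea under the PADDLE_CUDA_FP16 branch, not as the file's implementation.

#include <cstring>

PADDLE_HOSTDEVICE inline float16 float_to_half_rn_memcpy(float f) {
  half tmp = __float2half(f);
  float16 res;
  // Bit-copy the 16-bit payload instead of casting pointers between types.
  memcpy(&res.x, &tmp, sizeof(res.x));
  return res;
}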

paddle/math/tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ if(WITH_GPU)
     CUDA_ADD_EXECUTABLE(test_Tensor test_Tensor.cu)
     link_paddle_test(test_Tensor)
     CUDA_ADD_EXECUTABLE(test_lazyAssign test_lazyAssign.cu)
-    link_paddle_test(test_lazyAssign)
+    link_paddle_test(test_lazyAssign)
 else()
     compile_cu_as_cpp(test_Tensor.cu)
     add_unittest(test_Tensor test_Tensor.cu)
@@ -33,3 +33,4 @@ add_simple_unittest(test_FPException)
 add_simple_unittest(test_GpuProfiler)
 add_simple_unittest(test_BaseMatrix)
 add_simple_unittest(test_Matrix)
+add_simple_unittest(test_float16)
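
The new test_float16 target gives the float16 type CPU-side coverage. The test source itself is not shown in this view; below is a minimal sketch of what test_float16.cpp might check, assuming it uses gtest like the other simple unit tests and that float16 and its conversion helpers live in namespace paddle.

#include <gtest/gtest.h>

#include "paddle/math/float16.h"

namespace paddle {

TEST(Float16, FloatRoundTrip) {
  // 0.25f is exactly representable in half precision, so converting to
  // float16 and back should reproduce the value exactly.
  float16 h = float_to_half_rn(0.25f);
  EXPECT_EQ(0.25f, half_to_float(h));
}

TEST(Float16, EigenHalfInterop) {
  // Construction from Eigen::half should preserve the raw 16-bit payload.
  Eigen::half eh(1.0f);
  float16 h(eh);
  EXPECT_EQ(eh.x, h.x);
}

}  // namespace paddle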
