Commit 39ac9e3

authored
float16 type support enhance (#12181)
* cherry picked
* "cherry picked platform"
* "add comment"
* "fix ci"
1 parent 19ef4ba commit 39ac9e3

File tree

7 files changed (+334, -7 lines)


paddle/fluid/platform/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
 
 nv_test(float16_gpu_test SRCS float16_test.cu DEPS lod_tensor)
 cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
+
+IF(WITH_GPU)
+  nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
+ENDIF()

paddle/fluid/platform/cuda_device_function.h

Lines changed: 21 additions & 0 deletions
@@ -14,6 +14,10 @@ limitations under the License. */
 
 #pragma once
 #include <cuda.h>
+// NOTE(): support float16 to half in header file.
+#define PADDLE_CUDA_FP16
+#include <cuda_fp16.h>
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace platform {
@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
 #endif
 }
 
+// CUDA 9.0 has a natively compatible float16 shfl_down.
+#if CUDA_VERSION < 9000
+template <>
+__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
+                                                       float16 val, int delta,
+                                                       int width) {
+  half tmp = static_cast<half>(val);
+  __shfl_down(tmp, static_cast<unsigned>(delta), width);
+  return float16(tmp);
+}
+#endif
+
 template <typename T>
 __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                              int width = 32) {
@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
 #endif
 }
 
+template <typename T>
+HOSTDEVICE T Infinity() {
+  return INFINITY;
+}
+
 template <typename T>
 __device__ T reduceSum(T val, int tid, int len) {
   // NOTE(zcd): The warp size should be taken from the
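The specialization above only bridges the pre-CUDA-9.0 shuffle API; the warp-reduction pattern it supports looks roughly like the sketch below. This kernel is illustrative and not part of the commit: it assumes CUDA 9.0+, converts `__half` inputs to `float` before accumulating, and uses only standard CUDA intrinsics (`__half2float`, `__shfl_down_sync`).

```cpp
#include <cuda_fp16.h>

// Illustrative warp-level sum over 16-bit inputs: each step of the tree
// reduction folds the value held `offset` lanes away onto the current lane.
__global__ void WarpSumHalf(const __half* in, float* out) {
  int lane = threadIdx.x & 31;              // lane index inside the warp
  float v = __half2float(in[threadIdx.x]);  // do the arithmetic in float
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(0xffffffffu, v, offset);
  }
  if (lane == 0) out[threadIdx.x >> 5] = v;  // lane 0 holds the warp total
}
```

The `Infinity<T>()` helper added in the same hunk gives templated kernels a type-agnostic way to seed such reductions (for example, a max-reduce) without spelling out per-type `INFINITY` casts.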
paddle/fluid/platform/cuda_helper_test.cu

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@ (new file, shown in full)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include <random>

#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

using paddle::platform::PADDLE_CUDA_NUM_THREADS;
using paddle::platform::float16;

#define CUDA_ATOMIC_KERNEL(op, T)                                       \
  __global__ void op##Kernel(const T* data_a, T* data_b, size_t num) {  \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;        \
         i += blockDim.x * gridDim.x) {                                 \
      paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]);          \
    }                                                                   \
  }

template <typename T>
struct AddFunctor {
  T operator()(const T& a, const T& b) { return a + b; }
};

template <typename T>
struct SubFunctor {
  T operator()(const T& a, const T& b) { return a - b; }
};

// NOTE(dzhwinter): the float16 add has small underflow/overflow,
// so we use EXPECT_NEAR to check the result.
#define ARITHMETIC_KERNEL_LAUNCH(op, T)                                 \
  void Test##T##op(size_t num) {                                        \
    T *in1, *in2, *out;                                                 \
    T *d_in1, *d_in2;                                                   \
    size_t size = sizeof(T) * num;                                      \
    cudaMalloc(reinterpret_cast<void**>(&d_in1), size);                 \
    cudaMalloc(reinterpret_cast<void**>(&d_in2), size);                 \
    in1 = reinterpret_cast<T*>(malloc(size));                           \
    in2 = reinterpret_cast<T*>(malloc(size));                           \
    out = reinterpret_cast<T*>(malloc(size));                           \
    std::minstd_rand engine;                                            \
    std::uniform_real_distribution<double> dist(0.0, 1.0);              \
    for (size_t i = 0; i < num; ++i) {                                  \
      in1[i] = static_cast<T>(dist(engine));                            \
      in2[i] = static_cast<T>(dist(engine));                            \
    }                                                                   \
    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);               \
    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);               \
    op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);      \
    cudaDeviceSynchronize();                                            \
    cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);               \
    cudaDeviceSynchronize();                                            \
    for (size_t i = 0; i < num; ++i) {                                  \
      EXPECT_NEAR(static_cast<float>(out[i]),                           \
                  static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
                  0.001);                                               \
    }                                                                   \
    free(in1);                                                          \
    free(in2);                                                          \
    free(out);                                                          \
    cudaFree(d_in1);                                                    \
    cudaFree(d_in2);                                                    \
  }
CUDA_ATOMIC_KERNEL(Add, float);
CUDA_ATOMIC_KERNEL(Add, double);
CUDA_ATOMIC_KERNEL(Add, float16);

ARITHMETIC_KERNEL_LAUNCH(Add, float);
ARITHMETIC_KERNEL_LAUNCH(Add, double);
ARITHMETIC_KERNEL_LAUNCH(Add, float16);

namespace paddle {
namespace platform {
USE_CUDA_ATOMIC(Sub, int);
};
};
CUDA_ATOMIC_KERNEL(Sub, int);
ARITHMETIC_KERNEL_LAUNCH(Sub, int);

// cuda primitives
TEST(CudaAtomic, Add) {
  TestfloatAdd(static_cast<size_t>(10));
  TestfloatAdd(static_cast<size_t>(1024 * 1024));
  TestdoubleAdd(static_cast<size_t>(10));
  TestdoubleAdd(static_cast<size_t>(1024 * 1024));
}

TEST(CudaAtomic, Sub) {
  TestintSub(static_cast<size_t>(10));
  TestintSub(static_cast<size_t>(1024 * 1024));
}

TEST(CudaAtomic, float16) {
  using paddle::platform::float16;
  Testfloat16Add(static_cast<size_t>(1));
  Testfloat16Add(static_cast<size_t>(2));
  Testfloat16Add(static_cast<size_t>(3));

  Testfloat16Add(static_cast<size_t>(10));
  Testfloat16Add(static_cast<size_t>(1024 * 1024));
}
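For readers skimming the macros, this is roughly what `CUDA_ATOMIC_KERNEL(Add, float16)` expands to (hand-expanded here for illustration only):

```cpp
// Every thread atomically folds its element of data_a into data_b, so the
// test exercises the new float16 CudaAtomicAdd overload under contention.
__global__ void AddKernel(const float16* data_a, float16* data_b, size_t num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    paddle::platform::CudaAtomicAdd(&data_b[i], data_a[i]);
  }
}
```

The matching `Testfloat16Add` launcher fills both buffers with uniform random values, runs the kernel on a single block of `PADDLE_CUDA_NUM_THREADS`, and compares against `AddFunctor<float16>` with `EXPECT_NEAR`, since a single half-precision add can already lose low-order bits.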

paddle/fluid/platform/cuda_primitives.h

Lines changed: 69 additions & 6 deletions
@@ -14,12 +14,14 @@ limitations under the License. */
 
 #pragma once
 #include <cuda.h>
+#include <stdio.h>
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace platform {
 
 #define CUDA_ATOMIC_WRAPPER(op, T) \
-  __device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
+  __device__ __forceinline__ T CudaAtomic##op(T *address, const T val)
 
 #define USE_CUDA_ATOMIC(op, T) \
   CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
@@ -42,17 +44,17 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
   static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                 "long long should be int64");
   return CudaAtomicAdd(
-      reinterpret_cast<unsigned long long int*>(address),  // NOLINT
-      static_cast<unsigned long long int>(val));           // NOLINT
+      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
+      static_cast<unsigned long long int>(val));            // NOLINT
 }
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
 USE_CUDA_ATOMIC(Add, double);
 #else
 CUDA_ATOMIC_WRAPPER(Add, double) {
-  unsigned long long int* address_as_ull =                 // NOLINT
-      reinterpret_cast<unsigned long long int*>(address);  // NOLINT
-  unsigned long long int old = *address_as_ull, assumed;   // NOLINT
+  unsigned long long int *address_as_ull =                  // NOLINT
+      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;    // NOLINT
 
   do {
     assumed = old;
@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
 
   return __longlong_as_double(old);
 }
+#endif
+
+#ifdef PADDLE_CUDA_FP16
+// NOTE(dzhwinter): CUDA does not provide atomicCAS for half.
+// Instead, treat the half's address as part of a 32-bit unsigned word and
+// run atomicCAS on that word; depending on whether the value sits in the
+// high or the low 16 bits, do the corresponding sum and CAS.
+// Since most warp threads will fail the atomicCAS, this implementation
+// should be avoided under high concurrency; it is slower than converting
+// the value to 32 bits and doing a full-width atomicCAS.
+
+// convert the value into float, do the add arithmetic,
+// then store the result back into a uint32.
+inline __device__ uint32_t add_to_low_half(uint32_t val, float x) {
+  float16 low_half;
+  // the float16 in the lower 16 bits
+  low_half.x = static_cast<uint16_t>(val & 0xffffu);
+  low_half = static_cast<float16>(static_cast<float>(low_half) + x);
+  return (val & 0xffff0000u) | low_half.x;
+}
+
+inline __device__ uint32_t add_to_high_half(uint32_t val, float x) {
+  float16 high_half;
+  // the float16 in the higher 16 bits
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half = static_cast<float16>(static_cast<float>(high_half) + x);
+  return (val & 0xffffu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+CUDA_ATOMIC_WRAPPER(Add, float16) {
+  // the packed float16 value may live in either the lower or the higher
+  // 16 bits of the aligned 32-bit word.
+  uint32_t *address_as_ui =
+      reinterpret_cast<uint32_t *>(reinterpret_cast<char *>(address) -
+                                   (reinterpret_cast<size_t>(address) & 2));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t sum;
+  uint32_t newval;
+  uint32_t assumed;
+  if (((size_t)address & 2) == 0) {
+    // the float16 value stays in the lower 16 bits of the word.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    float16 ret;
+    ret.x = old & 0xffffu;
+    return ret;
+  } else {
+    // the float16 value stays in the higher 16 bits of the word.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, add_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    float16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
 #endif
 } // namespace platform
 } // namespace paddle
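The `(address & 2)` test works because a `float16` is 2-byte aligned, so it always occupies either the low or the high half of its enclosing 4-byte word; subtracting that offset from the address yields the aligned `uint32_t` that the CAS loop operates on. As a usage sketch (the kernel below is hypothetical, not part of the commit), the new overload is what makes contended half-precision accumulation safe:

```cpp
using paddle::platform::float16;

// Hypothetical scatter-add: several threads may target the same dst slot,
// so a plain '+=' would race; the new CudaAtomicAdd(float16*, float16)
// makes the accumulation correct, at the cost of the CAS retry loop above.
__global__ void ScatterAddFP16(const float16* src, const int* index,
                               float16* dst, size_t n) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    paddle::platform::CudaAtomicAdd(&dst[index[i]], src[i]);
  }
}
```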

paddle/fluid/platform/float16.h

Lines changed: 27 additions & 0 deletions
@@ -67,8 +67,11 @@ struct float16;
 } // namespace platform
 } // namespace paddle
 
+// NOTE():
+// Do not move the eigen.h header, otherwise eigen_vector<bool> will fail.
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/platform/hostdevice.h"
+#include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
 namespace platform {
@@ -898,6 +901,30 @@
          is_standard_layout<paddle::platform::float16>::value;
 };
 
+template <>
+struct is_floating_point<paddle::platform::float16>
+    : std::integral_constant<
+          bool, std::is_same<paddle::platform::float16,
+                             typename std::remove_cv<
+                                 paddle::platform::float16>::type>::value> {};
+template <>
+struct is_signed<paddle::platform::float16> {
+  static const bool value = true;
+};
+
+template <>
+struct is_unsigned<paddle::platform::float16> {
+  static const bool value = false;
+};
+
+inline bool isnan(const paddle::platform::float16& a) {
+  return paddle::platform::isnan(a);
+}
+
+inline bool isinf(const paddle::platform::float16& a) {
+  return paddle::platform::isinf(a);
+}
+
 template <>
 struct numeric_limits<paddle::platform::float16> {
   static const bool is_specialized = true;
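A short sketch of what these `std` specializations buy: generic numeric code that gates on `std::is_floating_point` (or calls `std::isnan`/`std::isinf` unqualified) now accepts `float16` like any other real type. The `Epsilon` helper below is hypothetical and written only to illustrate the dispatch; it assumes the `numeric_limits<float16>` specialization shown above provides `epsilon()`.

```cpp
#include <limits>
#include <type_traits>

#include "paddle/fluid/platform/float16.h"

// Hypothetical generic helper: rejects non-floating-point types at compile
// time. Before this patch, paddle::platform::float16 would fail the
// static_assert; with the is_floating_point specialization it passes.
template <typename T>
T Epsilon() {
  static_assert(std::is_floating_point<T>::value,
                "Epsilon() requires a floating-point type");
  return std::numeric_limits<T>::epsilon();
}

// Usage: auto eps = Epsilon<paddle::platform::float16>();
```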

paddle/fluid/platform/float16_test.cc

Lines changed: 26 additions & 0 deletions
@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
   }
 }
 
+TEST(float16, floating) {
+  // compile time assert.
+  PADDLE_ASSERT(std::is_floating_point<float16>::value);
+}
+
 TEST(float16, print) {
   float16 a = float16(1.0f);
   std::cout << a << std::endl;
 }
 
+// CPU test
+TEST(float16, isinf) {
+  float16 a;
+  a.x = 0x7c00;
+  float16 b = float16(INFINITY);
+  float16 c = static_cast<float16>(INFINITY);
+  EXPECT_EQ(std::isinf(a), true);
+  EXPECT_EQ(std::isinf(b), true);
+  EXPECT_EQ(std::isinf(c), true);
+}
+
+TEST(float16, isnan) {
+  float16 a;
+  a.x = 0x7fff;
+  float16 b = float16(NAN);
+  float16 c = static_cast<float16>(NAN);
+  EXPECT_EQ(std::isnan(a), true);
+  EXPECT_EQ(std::isnan(b), true);
+  EXPECT_EQ(std::isnan(c), true);
+}
+
 } // namespace platform
 } // namespace paddle
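The magic constants in these tests come straight from the IEEE 754 binary16 layout (1 sign bit, 5 exponent bits, 10 mantissa bits): an all-ones exponent with a zero mantissa is an infinity, and with a non-zero mantissa it is a NaN. A standalone check, independent of Paddle, that decodes the two constants:

```cpp
#include <cstdint>
#include <cstdio>

// Decode a binary16 bit pattern into its sign / exponent / mantissa fields.
void Decode(uint16_t bits) {
  unsigned sign = bits >> 15;
  unsigned exponent = (bits >> 10) & 0x1f;  // 5 exponent bits
  unsigned mantissa = bits & 0x3ff;         // 10 mantissa bits
  std::printf("0x%04x -> sign=%u exp=%u mant=%u (%s)\n", bits, sign, exponent,
              mantissa,
              exponent == 0x1f ? (mantissa == 0 ? "inf" : "nan") : "finite");
}

int main() {
  Decode(0x7c00);  // exponent all ones, mantissa 0    -> +infinity
  Decode(0x7fff);  // exponent all ones, mantissa != 0 -> NaN
  return 0;
}
```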
