
Commit 6d3da45

Fix/float16 style (#12446)
* "rewrite the test case" * "follow comment"
1 parent 91fb015 commit 6d3da45

File tree: 2 files changed (+119, -84 lines)


paddle/fluid/platform/cuda_helper_test.cu

Lines changed: 109 additions & 74 deletions
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include <gtest/gtest.h>
-#include <bitset>
 #include <iostream>
 #include <random>
 
@@ -25,94 +24,130 @@
 using paddle::platform::PADDLE_CUDA_NUM_THREADS;
 using paddle::platform::float16;
 
-#define CUDA_ATOMIC_KERNEL(op, T)                                      \
-  __global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;       \
-         i += blockDim.x * gridDim.x) {                                \
-      paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]);         \
-    }                                                                  \
+template <typename T>
+__global__ void AddKernel(const T* data_a, T* data_b, size_t num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    paddle::platform::CudaAtomicAdd(&data_b[i], data_a[i]);
   }
+}
 
 template <typename T>
 struct AddFunctor {
   T operator()(const T& a, const T& b) { return a + b; }
 };
 
 template <typename T>
-struct SubFunctor {
-  T operator()(const T& a, const T& b) { return a - b; }
-};
-
-// NOTE(dzhwinter): the float16 add has small underflow/overflow
-// so we use EXPECT_NEAR to check the result.
-#define ARITHMETIC_KERNEL_LAUNCH(op, T)                                 \
-  void Test##T##op(size_t num) {                                        \
-    T *in1, *in2, *out;                                                 \
-    T *d_in1, *d_in2;                                                   \
-    size_t size = sizeof(T) * num;                                      \
-    cudaMalloc(reinterpret_cast<void**>(&d_in1), size);                 \
-    cudaMalloc(reinterpret_cast<void**>(&d_in2), size);                 \
-    in1 = reinterpret_cast<T*>(malloc(size));                           \
-    in2 = reinterpret_cast<T*>(malloc(size));                           \
-    out = reinterpret_cast<T*>(malloc(size));                           \
-    std::minstd_rand engine;                                            \
-    std::uniform_real_distribution<double> dist(0.0, 1.0);              \
-    for (size_t i = 0; i < num; ++i) {                                  \
-      in1[i] = static_cast<T>(dist(engine));                            \
-      in2[i] = static_cast<T>(dist(engine));                            \
-    }                                                                   \
-    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);               \
-    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);               \
-    op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);      \
-    cudaDeviceSynchronize();                                            \
-    cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);               \
-    cudaDeviceSynchronize();                                            \
-    for (size_t i = 0; i < num; ++i) {                                  \
-      EXPECT_NEAR(static_cast<float>(out[i]),                           \
-                  static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
-                  0.001);                                               \
-    }                                                                   \
-    free(in1);                                                          \
-    free(in2);                                                          \
-    free(out);                                                          \
-    cudaFree(d_in1);                                                    \
-    cudaFree(d_in2);                                                    \
+void TestCase(size_t num) {
+  T *in1, *in2, *out;
+  T *d_in1, *d_in2;
+  size_t size = sizeof(T) * num;
+  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
+  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
+  in1 = reinterpret_cast<T*>(malloc(size));
+  in2 = reinterpret_cast<T*>(malloc(size));
+  out = reinterpret_cast<T*>(malloc(size));
+  std::minstd_rand engine;
+  std::uniform_real_distribution<double> dist(0.0, 1.0);
+  for (size_t i = 0; i < num; ++i) {
+    in1[i] = static_cast<T>(dist(engine));
+    in2[i] = static_cast<T>(dist(engine));
   }
-CUDA_ATOMIC_KERNEL(Add, float);
-CUDA_ATOMIC_KERNEL(Add, double);
-CUDA_ATOMIC_KERNEL(Add, float16);
-
-ARITHMETIC_KERNEL_LAUNCH(Add, float);
-ARITHMETIC_KERNEL_LAUNCH(Add, double);
-ARITHMETIC_KERNEL_LAUNCH(Add, float16);
-
-namespace paddle {
-namespace platform {
-USE_CUDA_ATOMIC(Sub, int);
-};
-};
-CUDA_ATOMIC_KERNEL(Sub, int);
-ARITHMETIC_KERNEL_LAUNCH(Sub, int);
+  cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);
+  AddKernel<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
+  cudaDeviceSynchronize();
+  cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+  for (size_t i = 0; i < num; ++i) {
+    // NOTE(dzhwinter): the float16 add has small underflow/overflow
+    // so we use EXPECT_NEAR to check the result.
+    EXPECT_NEAR(static_cast<float>(out[i]),
+                static_cast<float>(AddFunctor<T>()(in1[i], in2[i])), 0.001);
+  }
+  free(in1);
+  free(in2);
+  free(out);
+  cudaFree(d_in1);
+  cudaFree(d_in2);
+}
 
 // cuda primitives
 TEST(CudaAtomic, Add) {
-  TestfloatAdd(static_cast<size_t>(10));
-  TestfloatAdd(static_cast<size_t>(1024 * 1024));
-  TestdoubleAdd(static_cast<size_t>(10));
-  TestdoubleAdd(static_cast<size_t>(1024 * 1024));
-}
+  TestCase<float>(static_cast<size_t>(10));
+  TestCase<float>(static_cast<size_t>(1024 * 1024));
 
-TEST(CudaAtomic, Sub) {
-  TestintSub(static_cast<size_t>(10));
-  TestintSub(static_cast<size_t>(1024 * 1024));
+  TestCase<double>(static_cast<size_t>(10));
+  TestCase<double>(static_cast<size_t>(1024 * 1024));
 }
 
 TEST(CudaAtomic, float16) {
-  using paddle::platform::float16;
-  Testfloat16Add(static_cast<size_t>(1));
-  Testfloat16Add(static_cast<size_t>(2));
-  Testfloat16Add(static_cast<size_t>(3));
+  TestCase<float16>(static_cast<size_t>(1));
+  TestCase<float16>(static_cast<size_t>(2));
+  TestCase<float16>(static_cast<size_t>(3));
+
+  TestCase<float16>(static_cast<size_t>(10));
+  TestCase<float16>(static_cast<size_t>(1024 * 1024));
+}
+
+// unalignment of uint8
+void TestUnalign(size_t num, const int shift_bit) {
+  PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2");
+  float16 *in1, *in2, *out;
+  float16 *d_in1, *d_in2;
+  size_t size = sizeof(uint8_t) * (num + shift_bit);
+  size_t array_size = sizeof(float16) * (num / 2);
+
+  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
+  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
+  in1 = reinterpret_cast<float16*>(malloc(size));
+  in2 = reinterpret_cast<float16*>(malloc(size));
+  out = reinterpret_cast<float16*>(malloc(size));
+
+  // right shift 1, mimic the unalignment of address
+  float16* r_in1 =
+      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in1) + shift_bit);
+  float16* r_in2 =
+      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in2) + shift_bit);
+
+  std::minstd_rand engine;
+  std::uniform_real_distribution<double> dist(0.0, 1.0);
+  for (size_t i = 0; i < num / 2; ++i) {
+    r_in1[i] = static_cast<float16>(dist(engine));
+    r_in2[i] = static_cast<float16>(dist(engine));
+  }
+  cudaMemcpy(d_in1, r_in1, array_size, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_in2, r_in2, array_size, cudaMemcpyHostToDevice);
+  AddKernel<float16><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num / 2);
+  cudaDeviceSynchronize();
+  cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+  for (size_t i = 0; i < num / 2; ++i) {
+    // NOTE(dzhwinter): the float16 add has small underflow/overflow
+    // so we use EXPECT_NEAR to check the result.
+    EXPECT_NEAR(static_cast<float>(out[i]),
+                static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])),
+                0.001);
+  }
+  free(in1);
+  free(in2);
+  free(out);
+  cudaFree(d_in1);
+  cudaFree(d_in2);
+}
+
+TEST(CudaAtomic, float16Unalign) {
+  // same with float16 testcase
+  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 2);
+  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 2);
+  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 2);
+
+  // shift the address.
+  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 1);
+  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 1);
+  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 1);
 
-  Testfloat16Add(static_cast<size_t>(10));
-  Testfloat16Add(static_cast<size_t>(1024 * 1024));
+  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 3);
+  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 3);
+  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 3);
 }
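The test rewrite above replaces the token-pasting macros (CUDA_ATOMIC_KERNEL, ARITHMETIC_KERNEL_LAUNCH) with a kernel template plus a single TestCase<T> driver. Below is a minimal standalone sketch of that macro-to-template pattern; the file name and printf body are hypothetical illustrations, not code from the commit:

// macro_vs_template.cc -- illustrative sketch, not part of the commit.
#include <cstddef>
#include <cstdio>

// Before: the launcher macro pasted one function name per type,
//   #define ARITHMETIC_KERNEL_LAUNCH(op, T) void Test##T##op(size_t num) {...}
// so call sites read Testfloat16Add(10).

// After: one template definition, type-checked once and instantiated
// per element type by the compiler.
template <typename T>
void TestCase(size_t num) {
  std::printf("testing %zu elements of a %zu-byte type\n", num, sizeof(T));
}

int main() {
  TestCase<float>(10);   // was TestfloatAdd(10)
  TestCase<double>(10);  // was TestdoubleAdd(10)
  return 0;
}

Besides the more readable call sites, compiler errors now point at one template definition instead of a macro expansion.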

paddle/fluid/platform/cuda_primitives.h

Lines changed: 10 additions & 10 deletions
@@ -79,41 +79,41 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
 
 // convert the value into float and do the add arithmetic.
 // then store the result into a uint32.
-inline __device__ uint32_t add_to_low_half(uint32_t val, float x) {
+inline static __device__ uint32_t add_to_low_half(uint32_t val, float x) {
   float16 low_half;
   // the float16 in lower 16bits
-  low_half.x = static_cast<uint16_t>(val & 0xffffu);
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
   low_half = static_cast<float16>(static_cast<float>(low_half) + x);
-  return (val & 0xffff0000u) | low_half.x;
+  return (val & 0xFFFF0000u) | low_half.x;
 }
 
-inline __device__ uint32_t add_to_high_half(uint32_t val, float x) {
+inline static __device__ uint32_t add_to_high_half(uint32_t val, float x) {
   float16 high_half;
   // the float16 in higher 16bits
   high_half.x = static_cast<uint16_t>(val >> 16);
   high_half = static_cast<float16>(static_cast<float>(high_half) + x);
-  return (val & 0xffffu) | (static_cast<uint32_t>(high_half.x) << 16);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
 }
 
 CUDA_ATOMIC_WRAPPER(Add, float16) {
   // concrete packed float16 value may exsits in lower or higher 16bits
   // of the 32bits address.
-  uint32_t *address_as_ui =
-      reinterpret_cast<uint32_t *>(reinterpret_cast<char *>(address) -
-                                   (reinterpret_cast<size_t>(address) & 2));
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
   float val_f = static_cast<float>(val);
   uint32_t old = *address_as_ui;
   uint32_t sum;
   uint32_t newval;
   uint32_t assumed;
-  if (((size_t)address & 2) == 0) {
+  if (((uintptr_t)address & 0x02) == 0) {
     // the float16 value stay at lower 16 bits of the address.
     do {
       assumed = old;
       old = atomicCAS(address_as_ui, assumed, add_to_low_half(assumed, val_f));
     } while (old != assumed);
     float16 ret;
-    ret.x = old & 0xffffu;
+    ret.x = old & 0xFFFFu;
     return ret;
   } else {
     // the float16 value stay at higher 16 bits of the address.
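Hardware atomicCAS operates on 32-bit words, so CUDA_ATOMIC_WRAPPER(Add, float16) rounds the target address down to its containing 4-byte word and splices the updated 16 bits back into the low or high half. The host-side sketch below isolates just that addressing arithmetic; it is illustrative only, assumes the little-endian layout GPUs use, and its file and variable names are hypothetical:

// half_word_addressing.cc -- illustrative sketch, not part of the commit.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Two 16-bit values packed into one 4-byte-aligned 32-bit word.
  alignas(4) uint16_t pair[2] = {0xAAAA, 0xBBBB};
  for (int i = 0; i < 2; ++i) {
    uintptr_t p = reinterpret_cast<uintptr_t>(&pair[i]);
    // Same trick as the wrapper: clear bit 1 to reach the word start.
    const void* word = reinterpret_cast<const void*>(p - (p & 0x02));
    bool high = (p & 0x02) != 0;  // does the value sit in the upper 16 bits?
    uint32_t w;
    std::memcpy(&w, word, sizeof(w));  // aliasing-safe read of the whole word
    std::printf("pair[%d] occupies the %s half of word 0x%08X\n", i,
                high ? "high" : "low", static_cast<unsigned>(w));
  }
  return 0;
}

Note that TestUnalign in cuda_helper_test.cu shifts only its host buffers; the device pointers returned by cudaMalloc stay aligned, so the kernel's atomics always land on a well-formed 32-bit word.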
