@@ -13,7 +13,6 @@
 // limitations under the License.

 #include <gtest/gtest.h>
-#include <bitset>
 #include <iostream>
 #include <random>

@@ -25,94 +24,130 @@
 using paddle::platform::PADDLE_CUDA_NUM_THREADS;
 using paddle::platform::float16;

-#define CUDA_ATOMIC_KERNEL(op, T)                                        \
-  __global__ void op##Kernel(const T* data_a, T* data_b, size_t num) {  \
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;        \
-         i += blockDim.x * gridDim.x) {                                 \
-      paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]);          \
-    }                                                                    \
+template <typename T>
+__global__ void AddKernel(const T* data_a, T* data_b, size_t num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x) {
+    paddle::platform::CudaAtomicAdd(&data_b[i], data_a[i]);
   }
+}

 template <typename T>
 struct AddFunctor {
   T operator()(const T& a, const T& b) { return a + b; }
 };

 template <typename T>
-struct SubFunctor {
-  T operator()(const T& a, const T& b) { return a - b; }
-};
-
-// NOTE(dzhwinter): the float16 add has small underflow/overflow
-// so we use EXPECT_NEAR to check the result.
-#define ARITHMETIC_KERNEL_LAUNCH(op, T)                                  \
-  void Test##T##op(size_t num) {                                         \
-    T *in1, *in2, *out;                                                  \
-    T *d_in1, *d_in2;                                                    \
-    size_t size = sizeof(T) * num;                                       \
-    cudaMalloc(reinterpret_cast<void**>(&d_in1), size);                  \
-    cudaMalloc(reinterpret_cast<void**>(&d_in2), size);                  \
-    in1 = reinterpret_cast<T*>(malloc(size));                            \
-    in2 = reinterpret_cast<T*>(malloc(size));                            \
-    out = reinterpret_cast<T*>(malloc(size));                            \
-    std::minstd_rand engine;                                             \
-    std::uniform_real_distribution<double> dist(0.0, 1.0);               \
-    for (size_t i = 0; i < num; ++i) {                                   \
-      in1[i] = static_cast<T>(dist(engine));                             \
-      in2[i] = static_cast<T>(dist(engine));                             \
-    }                                                                    \
-    cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);                \
-    cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);                \
-    op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);       \
-    cudaDeviceSynchronize();                                             \
-    cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);                \
-    cudaDeviceSynchronize();                                             \
-    for (size_t i = 0; i < num; ++i) {                                   \
-      EXPECT_NEAR(static_cast<float>(out[i]),                            \
-                  static_cast<float>(op##Functor<T>()(in1[i], in2[i])),  \
-                  0.001);                                                \
-    }                                                                    \
-    free(in1);                                                           \
-    free(in2);                                                           \
-    free(out);                                                           \
-    cudaFree(d_in1);                                                     \
-    cudaFree(d_in2);                                                     \
+void TestCase(size_t num) {
+  T *in1, *in2, *out;
+  T *d_in1, *d_in2;
+  size_t size = sizeof(T) * num;
+  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
+  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
+  in1 = reinterpret_cast<T*>(malloc(size));
+  in2 = reinterpret_cast<T*>(malloc(size));
+  out = reinterpret_cast<T*>(malloc(size));
+  std::minstd_rand engine;
+  std::uniform_real_distribution<double> dist(0.0, 1.0);
+  for (size_t i = 0; i < num; ++i) {
+    in1[i] = static_cast<T>(dist(engine));
+    in2[i] = static_cast<T>(dist(engine));
   }
-CUDA_ATOMIC_KERNEL(Add, float);
-CUDA_ATOMIC_KERNEL(Add, double);
-CUDA_ATOMIC_KERNEL(Add, float16);
-
-ARITHMETIC_KERNEL_LAUNCH(Add, float);
-ARITHMETIC_KERNEL_LAUNCH(Add, double);
-ARITHMETIC_KERNEL_LAUNCH(Add, float16);
-
-namespace paddle {
-namespace platform {
-USE_CUDA_ATOMIC(Sub, int);
-};
-};
-CUDA_ATOMIC_KERNEL(Sub, int);
-ARITHMETIC_KERNEL_LAUNCH(Sub, int);
+  cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice);
+  AddKernel<T><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num);
+  cudaDeviceSynchronize();
+  cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+  for (size_t i = 0; i < num; ++i) {
+    // NOTE(dzhwinter): the float16 add has small underflow/overflow
+    // so we use EXPECT_NEAR to check the result.
+    EXPECT_NEAR(static_cast<float>(out[i]),
+                static_cast<float>(AddFunctor<T>()(in1[i], in2[i])), 0.001);
+  }
+  free(in1);
+  free(in2);
+  free(out);
+  cudaFree(d_in1);
+  cudaFree(d_in2);
+}

 // cuda primitives
 TEST(CudaAtomic, Add) {
-  TestfloatAdd(static_cast<size_t>(10));
-  TestfloatAdd(static_cast<size_t>(1024 * 1024));
-  TestdoubleAdd(static_cast<size_t>(10));
-  TestdoubleAdd(static_cast<size_t>(1024 * 1024));
-}
+  TestCase<float>(static_cast<size_t>(10));
+  TestCase<float>(static_cast<size_t>(1024 * 1024));

-TEST(CudaAtomic, Sub) {
-  TestintSub(static_cast<size_t>(10));
-  TestintSub(static_cast<size_t>(1024 * 1024));
+  TestCase<double>(static_cast<size_t>(10));
+  TestCase<double>(static_cast<size_t>(1024 * 1024));
 }

 TEST(CudaAtomic, float16) {
-  using paddle::platform::float16;
-  Testfloat16Add(static_cast<size_t>(1));
-  Testfloat16Add(static_cast<size_t>(2));
-  Testfloat16Add(static_cast<size_t>(3));
+  TestCase<float16>(static_cast<size_t>(1));
+  TestCase<float16>(static_cast<size_t>(2));
+  TestCase<float16>(static_cast<size_t>(3));
+
+  TestCase<float16>(static_cast<size_t>(10));
+  TestCase<float16>(static_cast<size_t>(1024 * 1024));
+}
+
+// exercise atomic add on unaligned float16 addresses (buffers shifted by a few bytes)
+void TestUnalign(size_t num, const int shift_bit) {
+  PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2");
+  float16 *in1, *in2, *out;
+  float16 *d_in1, *d_in2;
+  size_t size = sizeof(uint8_t) * (num + shift_bit);
+  size_t array_size = sizeof(float16) * (num / 2);
+
+  cudaMalloc(reinterpret_cast<void**>(&d_in1), size);
+  cudaMalloc(reinterpret_cast<void**>(&d_in2), size);
+  in1 = reinterpret_cast<float16*>(malloc(size));
+  in2 = reinterpret_cast<float16*>(malloc(size));
+  out = reinterpret_cast<float16*>(malloc(size));
+
+  // shift the buffer start by shift_bit bytes to mimic an unaligned address
+  float16* r_in1 =
+      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in1) + shift_bit);
+  float16* r_in2 =
+      reinterpret_cast<float16*>(reinterpret_cast<uint8_t*>(in2) + shift_bit);
+
+  std::minstd_rand engine;
+  std::uniform_real_distribution<double> dist(0.0, 1.0);
+  for (size_t i = 0; i < num / 2; ++i) {
+    r_in1[i] = static_cast<float16>(dist(engine));
+    r_in2[i] = static_cast<float16>(dist(engine));
+  }
+  cudaMemcpy(d_in1, r_in1, array_size, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_in2, r_in2, array_size, cudaMemcpyHostToDevice);
+  AddKernel<float16><<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num / 2);
+  cudaDeviceSynchronize();
+  cudaMemcpy(out, d_in2, array_size, cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+  for (size_t i = 0; i < num / 2; ++i) {
+    // NOTE(dzhwinter): the float16 add has small underflow/overflow
+    // so we use EXPECT_NEAR to check the result.
+    EXPECT_NEAR(static_cast<float>(out[i]),
+                static_cast<float>(AddFunctor<float16>()(r_in1[i], r_in2[i])),
+                0.001);
+  }
+  free(in1);
+  free(in2);
+  free(out);
+  cudaFree(d_in1);
+  cudaFree(d_in2);
+}
+
+TEST(CudaAtomic, float16Unalign) {
+  // same sizes as the float16 test case
+  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 2);
+  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 2);
+  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 2);
+
+  // shift the address by an odd number of bytes
+  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 1);
+  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 1);
+  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 1);

-  Testfloat16Add(static_cast<size_t>(10));
-  Testfloat16Add(static_cast<size_t>(1024 * 1024));
+  TestUnalign(static_cast<size_t>(2), /*shift_bit*/ 3);
+  TestUnalign(static_cast<size_t>(1024), /*shift_bit*/ 3);
+  TestUnalign(static_cast<size_t>(1024 * 1024), /*shift_bit*/ 3);
 }
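A note on the 0.001 tolerance: float16 carries a 10-bit mantissa, so for results near 1.0 the spacing between representable values is about 2^-10 ≈ 0.001. With both inputs drawn from [0, 1), the device-side half-precision sum can differ from the reference AddFunctor result by roughly one such step, which is presumably why TestCase and TestUnalign use EXPECT_NEAR with a 0.001 bound instead of exact equality.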
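The unaligned float16 cases are worth a second look because CUDA's atomicCAS only operates on aligned 32-bit (or larger) words, so a 16-bit atomic add is commonly emulated by CAS-ing the 4-byte word that contains the half and splicing the updated 16 bits back in. The sketch below illustrates that general pattern under the assumption that the half is at least 2-byte aligned; it uses standard cuda_fp16.h intrinsics and a hypothetical AtomicAddHalf helper, and is not necessarily how paddle::platform::CudaAtomicAdd for float16 is implemented.

#include <cuda_fp16.h>
#include <stdint.h>

// Illustrative sketch only (hypothetical helper, not Paddle's implementation):
// atomically add `val` to the __half at `addr`, where `addr` is 2-byte aligned
// but possibly not 4-byte aligned. atomicCAS works on aligned 32-bit words, so
// we CAS the containing word and splice the updated 16 bits back in.
__device__ void AtomicAddHalf(__half* addr, __half val) {
  // Round the address down to the containing 4-byte word.
  unsigned int* word = reinterpret_cast<unsigned int*>(
      reinterpret_cast<uintptr_t>(addr) & ~static_cast<uintptr_t>(0x3));
  // Does the half occupy the high or the low 16 bits of that word?
  const bool high_half = reinterpret_cast<uintptr_t>(addr) & 0x2;

  unsigned int old = *word;
  unsigned int assumed;
  do {
    assumed = old;
    // Extract the current half, add in float for portability, convert back.
    unsigned short bits = high_half
                              ? static_cast<unsigned short>(assumed >> 16)
                              : static_cast<unsigned short>(assumed & 0xFFFFu);
    float sum = __half2float(__ushort_as_half(bits)) + __half2float(val);
    unsigned short new_bits = __half_as_ushort(__float2half(sum));
    unsigned int new_word =
        high_half
            ? (assumed & 0x0000FFFFu) | (static_cast<unsigned int>(new_bits) << 16)
            : (assumed & 0xFFFF0000u) | static_cast<unsigned int>(new_bits);
    // If another thread changed the word in the meantime, retry with its value.
    old = atomicCAS(word, assumed, new_word);
  } while (assumed != old);
}

Note that in TestUnalign the shifted pointers r_in1/r_in2 live on the host; the cudaMalloc'ed d_in1/d_in2 handed to AddKernel are still naturally aligned, so the device-side atomics do not appear to see an odd address here, and the odd-byte shifts (1 and 3) mainly exercise packing float16 data through unaligned host buffers.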