
Commit 723dc58

Author: Raghuveer Devulapalli (committed)
Revert "Move classes to a separate header file"
This reverts commit 66ec396.
1 parent 8c2066a · commit 723dc58

7 files changed: +1136 −1149 lines

src/avx512-16bit-common.h

Lines changed: 15 additions & 0 deletions
@@ -14,6 +14,21 @@
  * sorting network (see
  * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
  */
+// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
+static const uint16_t network[6][32]
+        = {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
+            23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
+           {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+            31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
+           {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
+            20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
+           {31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
+            15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
+           {8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
+            24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
+           {16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
+
 /*
  * Assumes zmm is random and performs a full sorting network defined in
  * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
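Note on the table above: each row is a permutation index vector for _mm512_permutexvar_epi16; row i is what get_network(i + 1) loads in the specializations below. A minimal standalone sketch (the local row0 copy and the demo itself are illustrative, not part of this commit; assumes an AVX-512BW target) showing that row 0 reverses each group of 8 adjacent 16-bit lanes:

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main()
{
    // Local copy of network[0] above, so the demo is self-contained.
    static const uint16_t row0[32]
            = {7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
               23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24};
    uint16_t in[32], out[32];
    for (uint16_t i = 0; i < 32; ++i)
        in[i] = i;
    __m512i idx = _mm512_loadu_si512(row0);
    __m512i v = _mm512_loadu_si512(in);
    // Destination lane i receives source lane idx[i] (AVX512BW).
    _mm512_storeu_si512(out, _mm512_permutexvar_epi16(idx, v));
    for (int i = 0; i < 32; ++i)
        printf("%d ", out[i]); // 7 6 5 4 3 2 1 0 15 14 ... 25 24
    printf("\n");
    return 0;
}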

src/avx512-16bit-qsort.hpp

Lines changed: 340 additions & 0 deletions
@@ -9,6 +9,346 @@
 
 #include "avx512-16bit-common.h"
 
+struct float16 {
+    uint16_t val;
+};
+
+template <>
+struct zmm_vector<float16> {
+    using type_t = uint16_t;
+    using zmm_t = __m512i;
+    using ymm_t = __m256i;
+    using opmask_t = __mmask32;
+    static const uint8_t numlanes = 32;
+
+    static zmm_t get_network(int index)
+    {
+        return _mm512_loadu_si512(&network[index - 1][0]);
+    }
+    static type_t type_max()
+    {
+        return X86_SIMD_SORT_INFINITYH;
+    }
+    static type_t type_min()
+    {
+        return X86_SIMD_SORT_NEGINFINITYH;
+    }
+    static zmm_t zmm_max()
+    {
+        return _mm512_set1_epi16(type_max());
+    }
+    static opmask_t knot_opmask(opmask_t x)
+    {
+        return _knot_mask32(x);
+    }
+
+    static opmask_t ge(zmm_t x, zmm_t y)
+    {
+        zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000));
+        zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000));
+        zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00));
+        zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00));
+        zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff));
+        zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff));
+
+        __mmask32 mask_ge = _mm512_cmp_epu16_mask(
+                sign_x, sign_y, _MM_CMPINT_LT); // only greater than
+        __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y);
+        __mmask32 neg = _mm512_mask_cmpeq_epu16_mask(
+                sign_eq,
+                sign_x,
+                _mm512_set1_epi16(0x8000)); // both numbers are -ve
+
+        // compare exponents only if signs are equal:
+        mask_ge = mask_ge
+                | _mm512_mask_cmp_epu16_mask(
+                        sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
+        // get mask for elements for which both sign and exponents are equal:
+        __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y);
+
+        // compare mantissa for elements for which both sign and exponent are equal:
+        mask_ge = mask_ge
+                | _mm512_mask_cmp_epu16_mask(
+                        exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
+        return _kxor_mask32(mask_ge, neg);
+    }
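A scalar model of ge() may help review (a hypothetical sketch, not part of this diff): the three masked compares realize a lexicographic compare on the binary16 sign, exponent, and mantissa fields, and the final kxor with neg flips the result where both inputs are negative, since there a larger magnitude means a smaller value:

#include <cstdint>

// Hypothetical scalar equivalent of one lane of ge(); mirrors the masked
// compares above, including their behavior on equal negative inputs.
static bool ge_scalar(uint16_t x, uint16_t y)
{
    uint16_t sign_x = x & 0x8000, sign_y = y & 0x8000;
    uint16_t exp_x = x & 0x7c00, exp_y = y & 0x7c00;
    uint16_t mant_x = x & 0x03ff, mant_y = y & 0x03ff;

    bool mask_ge = sign_x < sign_y; // x non-negative, y negative
    bool sign_eq = (sign_x == sign_y);
    bool neg = sign_eq && (sign_x == 0x8000); // both negative

    if (sign_eq && exp_x > exp_y) mask_ge = true; // exponents, same sign
    bool exp_eq = sign_eq && (exp_x == exp_y);
    if (exp_eq && mant_x >= mant_y) mask_ge = true; // mantissas last

    // Two negatives order by reversed magnitude, hence the final XOR.
    return mask_ge != neg;
}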
+    static zmm_t loadu(void const *mem)
+    {
+        return _mm512_loadu_si512(mem);
+    }
+    static zmm_t max(zmm_t x, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(y, ge(x, y), x);
+    }
+    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
+    {
+        // AVX512_VBMI2
+        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
+    }
+    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
+    {
+        // AVX512BW
+        return _mm512_mask_loadu_epi16(x, mask, mem);
+    }
+    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, mask, y);
+    }
+    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_storeu_epi16(mem, mask, x);
+    }
+    static zmm_t min(zmm_t x, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, ge(x, y), y);
+    }
+    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
+    {
+        return _mm512_permutexvar_epi16(idx, zmm);
+    }
+    // Apparently this is terrible for perf, npy_half_to_float seems to work
+    // better
+    //static float uint16_to_float(uint16_t val)
+    //{
+    //    // Ideally use _mm_loadu_si16, but it's only in gcc > 11.x
+    //    // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM
+    //    __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val);
+    //    __m128 xmm2 = _mm_cvtph_ps(xmm);
+    //    return _mm_cvtss_f32(xmm2);
+    //}
+    static type_t float_to_uint16(float val)
+    {
+        __m128 xmm = _mm_load_ss(&val);
+        __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC);
+        return _mm_extract_epi16(xmm2, 0);
+    }
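A quick sanity check for the conversion path above (a hypothetical sketch, assuming the specialization is in scope and the CPU supports F16C): 1.0f encodes as binary16 0x3c00 (sign 0, biased exponent 15, mantissa 0).

#include <cstdio>

// Hypothetical check, not part of this commit.
void demo_cvt()
{
    printf("0x%04x\n",
           (unsigned)zmm_vector<float16>::float_to_uint16(1.0f)); // 0x3c00
}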
+    static type_t reducemax(zmm_t v)
+    {
+        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
+        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
+        float lo_max = _mm512_reduce_max_ps(lo);
+        float hi_max = _mm512_reduce_max_ps(hi);
+        return float_to_uint16(std::max(lo_max, hi_max));
+    }
+    static type_t reducemin(zmm_t v)
+    {
+        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
+        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
+        float lo_min = _mm512_reduce_min_ps(lo);
+        float hi_min = _mm512_reduce_min_ps(hi);
+        return float_to_uint16(std::min(lo_min, hi_min));
+    }
+    static zmm_t set1(type_t v)
+    {
+        return _mm512_set1_epi16(v);
+    }
+    template <uint8_t mask>
+    static zmm_t shuffle(zmm_t zmm)
+    {
+        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
+        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
+    }
+    static void storeu(void *mem, zmm_t x)
+    {
+        return _mm512_storeu_si512(mem, x);
+    }
+};
+
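reducemax/reducemin widen each 256-bit half to fp32 with _mm512_cvtph_ps because AVX-512 has no half-precision horizontal reduce. A hypothetical smoke test (not part of this commit; assumes the specialization above is in scope and an AVX-512BW capable CPU):

#include <cstdint>
#include <cstdio>

void demo_fp16_reduce()
{
    using vtype = zmm_vector<float16>;
    uint16_t buf[32];
    for (int i = 0; i < 32; ++i)
        buf[i] = vtype::float_to_uint16((float)i); // 0.0 .. 31.0 as binary16
    vtype::zmm_t v = vtype::loadu(buf);
    printf("max bits = 0x%04x\n", (unsigned)vtype::reducemax(v)); // 0x4fc0 == 31.0
    printf("min bits = 0x%04x\n", (unsigned)vtype::reducemin(v)); // 0x0000 == 0.0
}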
+template <>
+struct zmm_vector<int16_t> {
+    using type_t = int16_t;
+    using zmm_t = __m512i;
+    using ymm_t = __m256i;
+    using opmask_t = __mmask32;
+    static const uint8_t numlanes = 32;
+
+    static zmm_t get_network(int index)
+    {
+        return _mm512_loadu_si512(&network[index - 1][0]);
+    }
+    static type_t type_max()
+    {
+        return X86_SIMD_SORT_MAX_INT16;
+    }
+    static type_t type_min()
+    {
+        return X86_SIMD_SORT_MIN_INT16;
+    }
+    static zmm_t zmm_max()
+    {
+        return _mm512_set1_epi16(type_max());
+    }
+    static opmask_t knot_opmask(opmask_t x)
+    {
+        return _knot_mask32(x);
+    }
+
+    static opmask_t ge(zmm_t x, zmm_t y)
+    {
+        return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
+    }
+    static zmm_t loadu(void const *mem)
+    {
+        return _mm512_loadu_si512(mem);
+    }
+    static zmm_t max(zmm_t x, zmm_t y)
+    {
+        return _mm512_max_epi16(x, y);
+    }
+    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
+    {
+        // AVX512_VBMI2
+        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
+    }
+    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
+    {
+        // AVX512BW
+        return _mm512_mask_loadu_epi16(x, mask, mem);
+    }
+    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, mask, y);
+    }
+    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_storeu_epi16(mem, mask, x);
+    }
+    static zmm_t min(zmm_t x, zmm_t y)
+    {
+        return _mm512_min_epi16(x, y);
+    }
+    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
+    {
+        return _mm512_permutexvar_epi16(idx, zmm);
+    }
+    static type_t reducemax(zmm_t v)
+    {
+        zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0));
+        zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1));
+        type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo);
+        type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi);
+        return std::max(lo_max, hi_max);
+    }
+    static type_t reducemin(zmm_t v)
+    {
+        zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0));
+        zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1));
+        type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo);
+        type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi);
+        return std::min(lo_min, hi_min);
+    }
+    static zmm_t set1(type_t v)
+    {
+        return _mm512_set1_epi16(v);
+    }
+    template <uint8_t mask>
+    static zmm_t shuffle(zmm_t zmm)
+    {
+        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
+        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
+    }
+    static void storeu(void *mem, zmm_t x)
+    {
+        return _mm512_storeu_si512(mem, x);
+    }
+};
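Here the reductions stay integral: each half is sign-extended to 32-bit lanes because the horizontal reduce intrinsics exist only for epi32 (the uint16_t specialization below zero-extends with _mm512_cvtepu16_epi32 instead). A hypothetical check (not part of this commit; assumes the specialization above is in scope):

#include <immintrin.h>
#include <cstdio>

void demo_i16_reduce()
{
    int16_t buf[32];
    for (int i = 0; i < 32; ++i)
        buf[i] = (int16_t)(i - 16); // -16 .. 15
    __m512i v = _mm512_loadu_si512(buf);
    printf("max = %d\n", zmm_vector<int16_t>::reducemax(v)); // 15
    printf("min = %d\n", zmm_vector<int16_t>::reducemin(v)); // -16
}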
+template <>
+struct zmm_vector<uint16_t> {
+    using type_t = uint16_t;
+    using zmm_t = __m512i;
+    using ymm_t = __m256i;
+    using opmask_t = __mmask32;
+    static const uint8_t numlanes = 32;
+
+    static zmm_t get_network(int index)
+    {
+        return _mm512_loadu_si512(&network[index - 1][0]);
+    }
+    static type_t type_max()
+    {
+        return X86_SIMD_SORT_MAX_UINT16;
+    }
+    static type_t type_min()
+    {
+        return 0;
+    }
+    static zmm_t zmm_max()
+    {
+        return _mm512_set1_epi16(type_max());
+    }
+
+    static opmask_t knot_opmask(opmask_t x)
+    {
+        return _knot_mask32(x);
+    }
+    static opmask_t ge(zmm_t x, zmm_t y)
+    {
+        return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
+    }
+    static zmm_t loadu(void const *mem)
+    {
+        return _mm512_loadu_si512(mem);
+    }
+    static zmm_t max(zmm_t x, zmm_t y)
+    {
+        return _mm512_max_epu16(x, y);
+    }
+    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
+    }
+    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
+    {
+        return _mm512_mask_loadu_epi16(x, mask, mem);
+    }
+    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
+    {
+        return _mm512_mask_mov_epi16(x, mask, y);
+    }
+    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
+    {
+        return _mm512_mask_storeu_epi16(mem, mask, x);
+    }
+    static zmm_t min(zmm_t x, zmm_t y)
+    {
+        return _mm512_min_epu16(x, y);
+    }
+    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
+    {
+        return _mm512_permutexvar_epi16(idx, zmm);
+    }
+    static type_t reducemax(zmm_t v)
+    {
+        zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0));
+        zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1));
+        type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo);
+        type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi);
+        return std::max(lo_max, hi_max);
+    }
+    static type_t reducemin(zmm_t v)
+    {
+        zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0));
+        zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1));
+        type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo);
+        type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi);
+        return std::min(lo_min, hi_min);
+    }
+    static zmm_t set1(type_t v)
+    {
+        return _mm512_set1_epi16(v);
+    }
+    template <uint8_t mask>
+    static zmm_t shuffle(zmm_t zmm)
+    {
+        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
+        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
+    }
+    static void storeu(void *mem, zmm_t x)
+    {
+        return _mm512_storeu_si512(mem, x);
+    }
+};
+
 template <>
 bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
 {
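(The diff is truncated here by the viewer.) For context, a hypothetical usage sketch, not part of this diff: avx512_qsort<T> is the entry point these zmm_vector specializations plug into, and its exact declaration lives elsewhere in this repository.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

void demo_sort()
{
    std::vector<int16_t> v = {3, -1, 32767, -32768, 0, 7};
    avx512_qsort<int16_t>(v.data(), v.size()); // hypothetical entry point
    assert(std::is_sorted(v.begin(), v.end()));
}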
