Skip to content

Commit f9e3db7

Browse files
author
Raghuveer Devulapalli
committed
Move classes to a separate header file
1 parent ac6c10c commit f9e3db7

7 files changed

+1150
-1136
lines changed

src/avx512-16bit-common.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,6 @@
1414
* sorting network (see
1515
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
1616
*/
17-
// ZMM register: 31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0
18-
static const uint16_t network[6][32]
19-
= {{7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8,
20-
23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24},
21-
{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
22-
31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16},
23-
{4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11,
24-
20, 21, 22, 23, 16, 17, 18, 19, 28, 29, 30, 31, 24, 25, 26, 27},
25-
{31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
26-
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0},
27-
{8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
28-
24, 25, 26, 27, 28, 29, 30, 31, 16, 17, 18, 19, 20, 21, 22, 23},
29-
{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
30-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}};
31-
3217
/*
3318
* Assumes zmm is random and performs a full sorting network defined in
3419
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg

src/avx512-16bit-qsort.hpp

Lines changed: 0 additions & 340 deletions
Original file line numberDiff line numberDiff line change
@@ -9,346 +9,6 @@
99

1010
#include "avx512-16bit-common.h"
1111

12-
struct float16 {
13-
uint16_t val;
14-
};
15-
16-
template <>
17-
struct zmm_vector<float16> {
18-
using type_t = uint16_t;
19-
using zmm_t = __m512i;
20-
using ymm_t = __m256i;
21-
using opmask_t = __mmask32;
22-
static const uint8_t numlanes = 32;
23-
24-
static zmm_t get_network(int index)
25-
{
26-
return _mm512_loadu_si512(&network[index - 1][0]);
27-
}
28-
static type_t type_max()
29-
{
30-
return X86_SIMD_SORT_INFINITYH;
31-
}
32-
static type_t type_min()
33-
{
34-
return X86_SIMD_SORT_NEGINFINITYH;
35-
}
36-
static zmm_t zmm_max()
37-
{
38-
return _mm512_set1_epi16(type_max());
39-
}
40-
static opmask_t knot_opmask(opmask_t x)
41-
{
42-
return _knot_mask32(x);
43-
}
44-
45-
static opmask_t ge(zmm_t x, zmm_t y)
46-
{
47-
zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000));
48-
zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000));
49-
zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00));
50-
zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00));
51-
zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff));
52-
zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff));
53-
54-
__mmask32 mask_ge = _mm512_cmp_epu16_mask(
55-
sign_x, sign_y, _MM_CMPINT_LT); // only greater than
56-
__mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y);
57-
__mmask32 neg = _mm512_mask_cmpeq_epu16_mask(
58-
sign_eq,
59-
sign_x,
60-
_mm512_set1_epi16(0x8000)); // both numbers are -ve
61-
62-
// compare exponents only if signs are equal:
63-
mask_ge = mask_ge
64-
| _mm512_mask_cmp_epu16_mask(
65-
sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
66-
// get mask for elements for which both sign and exponents are equal:
67-
__mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y);
68-
69-
// compare mantissa for elements for which both sign and expponent are equal:
70-
mask_ge = mask_ge
71-
| _mm512_mask_cmp_epu16_mask(
72-
exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
73-
return _kxor_mask32(mask_ge, neg);
74-
}
75-
static zmm_t loadu(void const *mem)
76-
{
77-
return _mm512_loadu_si512(mem);
78-
}
79-
static zmm_t max(zmm_t x, zmm_t y)
80-
{
81-
return _mm512_mask_mov_epi16(y, ge(x, y), x);
82-
}
83-
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
84-
{
85-
// AVX512_VBMI2
86-
return _mm512_mask_compressstoreu_epi16(mem, mask, x);
87-
}
88-
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
89-
{
90-
// AVX512BW
91-
return _mm512_mask_loadu_epi16(x, mask, mem);
92-
}
93-
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
94-
{
95-
return _mm512_mask_mov_epi16(x, mask, y);
96-
}
97-
static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
98-
{
99-
return _mm512_mask_storeu_epi16(mem, mask, x);
100-
}
101-
static zmm_t min(zmm_t x, zmm_t y)
102-
{
103-
return _mm512_mask_mov_epi16(x, ge(x, y), y);
104-
}
105-
static zmm_t permutexvar(__m512i idx, zmm_t zmm)
106-
{
107-
return _mm512_permutexvar_epi16(idx, zmm);
108-
}
109-
// Apparently this is a terrible for perf, npy_half_to_float seems to work
110-
// better
111-
//static float uint16_to_float(uint16_t val)
112-
//{
113-
// // Ideally use _mm_loadu_si16, but its only gcc > 11.x
114-
// // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM
115-
// __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val);
116-
// __m128 xmm2 = _mm_cvtph_ps(xmm);
117-
// return _mm_cvtss_f32(xmm2);
118-
//}
119-
static type_t float_to_uint16(float val)
120-
{
121-
__m128 xmm = _mm_load_ss(&val);
122-
__m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC);
123-
return _mm_extract_epi16(xmm2, 0);
124-
}
125-
static type_t reducemax(zmm_t v)
126-
{
127-
__m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
128-
__m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
129-
float lo_max = _mm512_reduce_max_ps(lo);
130-
float hi_max = _mm512_reduce_max_ps(hi);
131-
return float_to_uint16(std::max(lo_max, hi_max));
132-
}
133-
static type_t reducemin(zmm_t v)
134-
{
135-
__m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
136-
__m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
137-
float lo_max = _mm512_reduce_min_ps(lo);
138-
float hi_max = _mm512_reduce_min_ps(hi);
139-
return float_to_uint16(std::min(lo_max, hi_max));
140-
}
141-
static zmm_t set1(type_t v)
142-
{
143-
return _mm512_set1_epi16(v);
144-
}
145-
template <uint8_t mask>
146-
static zmm_t shuffle(zmm_t zmm)
147-
{
148-
zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
149-
return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
150-
}
151-
static void storeu(void *mem, zmm_t x)
152-
{
153-
return _mm512_storeu_si512(mem, x);
154-
}
155-
};
156-
157-
template <>
158-
struct zmm_vector<int16_t> {
159-
using type_t = int16_t;
160-
using zmm_t = __m512i;
161-
using ymm_t = __m256i;
162-
using opmask_t = __mmask32;
163-
static const uint8_t numlanes = 32;
164-
165-
static zmm_t get_network(int index)
166-
{
167-
return _mm512_loadu_si512(&network[index - 1][0]);
168-
}
169-
static type_t type_max()
170-
{
171-
return X86_SIMD_SORT_MAX_INT16;
172-
}
173-
static type_t type_min()
174-
{
175-
return X86_SIMD_SORT_MIN_INT16;
176-
}
177-
static zmm_t zmm_max()
178-
{
179-
return _mm512_set1_epi16(type_max());
180-
}
181-
static opmask_t knot_opmask(opmask_t x)
182-
{
183-
return _knot_mask32(x);
184-
}
185-
186-
static opmask_t ge(zmm_t x, zmm_t y)
187-
{
188-
return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
189-
}
190-
static zmm_t loadu(void const *mem)
191-
{
192-
return _mm512_loadu_si512(mem);
193-
}
194-
static zmm_t max(zmm_t x, zmm_t y)
195-
{
196-
return _mm512_max_epi16(x, y);
197-
}
198-
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
199-
{
200-
// AVX512_VBMI2
201-
return _mm512_mask_compressstoreu_epi16(mem, mask, x);
202-
}
203-
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
204-
{
205-
// AVX512BW
206-
return _mm512_mask_loadu_epi16(x, mask, mem);
207-
}
208-
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
209-
{
210-
return _mm512_mask_mov_epi16(x, mask, y);
211-
}
212-
static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
213-
{
214-
return _mm512_mask_storeu_epi16(mem, mask, x);
215-
}
216-
static zmm_t min(zmm_t x, zmm_t y)
217-
{
218-
return _mm512_min_epi16(x, y);
219-
}
220-
static zmm_t permutexvar(__m512i idx, zmm_t zmm)
221-
{
222-
return _mm512_permutexvar_epi16(idx, zmm);
223-
}
224-
static type_t reducemax(zmm_t v)
225-
{
226-
zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0));
227-
zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1));
228-
type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo);
229-
type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi);
230-
return std::max(lo_max, hi_max);
231-
}
232-
static type_t reducemin(zmm_t v)
233-
{
234-
zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0));
235-
zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1));
236-
type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo);
237-
type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi);
238-
return std::min(lo_min, hi_min);
239-
}
240-
static zmm_t set1(type_t v)
241-
{
242-
return _mm512_set1_epi16(v);
243-
}
244-
template <uint8_t mask>
245-
static zmm_t shuffle(zmm_t zmm)
246-
{
247-
zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
248-
return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
249-
}
250-
static void storeu(void *mem, zmm_t x)
251-
{
252-
return _mm512_storeu_si512(mem, x);
253-
}
254-
};
255-
template <>
256-
struct zmm_vector<uint16_t> {
257-
using type_t = uint16_t;
258-
using zmm_t = __m512i;
259-
using ymm_t = __m256i;
260-
using opmask_t = __mmask32;
261-
static const uint8_t numlanes = 32;
262-
263-
static zmm_t get_network(int index)
264-
{
265-
return _mm512_loadu_si512(&network[index - 1][0]);
266-
}
267-
static type_t type_max()
268-
{
269-
return X86_SIMD_SORT_MAX_UINT16;
270-
}
271-
static type_t type_min()
272-
{
273-
return 0;
274-
}
275-
static zmm_t zmm_max()
276-
{
277-
return _mm512_set1_epi16(type_max());
278-
}
279-
280-
static opmask_t knot_opmask(opmask_t x)
281-
{
282-
return _knot_mask32(x);
283-
}
284-
static opmask_t ge(zmm_t x, zmm_t y)
285-
{
286-
return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
287-
}
288-
static zmm_t loadu(void const *mem)
289-
{
290-
return _mm512_loadu_si512(mem);
291-
}
292-
static zmm_t max(zmm_t x, zmm_t y)
293-
{
294-
return _mm512_max_epu16(x, y);
295-
}
296-
static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
297-
{
298-
return _mm512_mask_compressstoreu_epi16(mem, mask, x);
299-
}
300-
static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
301-
{
302-
return _mm512_mask_loadu_epi16(x, mask, mem);
303-
}
304-
static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
305-
{
306-
return _mm512_mask_mov_epi16(x, mask, y);
307-
}
308-
static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
309-
{
310-
return _mm512_mask_storeu_epi16(mem, mask, x);
311-
}
312-
static zmm_t min(zmm_t x, zmm_t y)
313-
{
314-
return _mm512_min_epu16(x, y);
315-
}
316-
static zmm_t permutexvar(__m512i idx, zmm_t zmm)
317-
{
318-
return _mm512_permutexvar_epi16(idx, zmm);
319-
}
320-
static type_t reducemax(zmm_t v)
321-
{
322-
zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0));
323-
zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1));
324-
type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo);
325-
type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi);
326-
return std::max(lo_max, hi_max);
327-
}
328-
static type_t reducemin(zmm_t v)
329-
{
330-
zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0));
331-
zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1));
332-
type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo);
333-
type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi);
334-
return std::min(lo_min, hi_min);
335-
}
336-
static zmm_t set1(type_t v)
337-
{
338-
return _mm512_set1_epi16(v);
339-
}
340-
template <uint8_t mask>
341-
static zmm_t shuffle(zmm_t zmm)
342-
{
343-
zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
344-
return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
345-
}
346-
static void storeu(void *mem, zmm_t x)
347-
{
348-
return _mm512_storeu_si512(mem, x);
349-
}
350-
};
351-
35212
template <>
35313
bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
35414
{

0 commit comments

Comments
 (0)