
 #include "avx512-16bit-common.h"

-struct float16 {
-    uint16_t val;
-};
-
-template <>
-struct zmm_vector<float16> {
-    using type_t = uint16_t;
-    using zmm_t = __m512i;
-    using ymm_t = __m256i;
-    using opmask_t = __mmask32;
-    static const uint8_t numlanes = 32;
-
-    static zmm_t get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
-    static type_t type_max()
-    {
-        return X86_SIMD_SORT_INFINITYH;
-    }
-    static type_t type_min()
-    {
-        return X86_SIMD_SORT_NEGINFINITYH;
-    }
-    static zmm_t zmm_max()
-    {
-        return _mm512_set1_epi16(type_max());
-    }
-    static opmask_t knot_opmask(opmask_t x)
-    {
-        return _knot_mask32(x);
-    }
-
-    static opmask_t ge(zmm_t x, zmm_t y)
-    {
-        zmm_t sign_x = _mm512_and_si512(x, _mm512_set1_epi16(0x8000));
-        zmm_t sign_y = _mm512_and_si512(y, _mm512_set1_epi16(0x8000));
-        zmm_t exp_x = _mm512_and_si512(x, _mm512_set1_epi16(0x7c00));
-        zmm_t exp_y = _mm512_and_si512(y, _mm512_set1_epi16(0x7c00));
-        zmm_t mant_x = _mm512_and_si512(x, _mm512_set1_epi16(0x3ff));
-        zmm_t mant_y = _mm512_and_si512(y, _mm512_set1_epi16(0x3ff));
-
-        __mmask32 mask_ge = _mm512_cmp_epu16_mask(
-                sign_x, sign_y, _MM_CMPINT_LT); // only greater than
-        __mmask32 sign_eq = _mm512_cmpeq_epu16_mask(sign_x, sign_y);
-        __mmask32 neg = _mm512_mask_cmpeq_epu16_mask(
-                sign_eq,
-                sign_x,
-                _mm512_set1_epi16(0x8000)); // both numbers are -ve
-
-        // compare exponents only if signs are equal:
-        mask_ge = mask_ge
-                | _mm512_mask_cmp_epu16_mask(
-                        sign_eq, exp_x, exp_y, _MM_CMPINT_NLE);
-        // get mask for elements for which both sign and exponents are equal:
-        __mmask32 exp_eq = _mm512_mask_cmpeq_epu16_mask(sign_eq, exp_x, exp_y);
-
-        // compare mantissa for elements for which both sign and exponent are equal:
-        mask_ge = mask_ge
-                | _mm512_mask_cmp_epu16_mask(
-                        exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
-        return _kxor_mask32(mask_ge, neg);
-    }
-    static zmm_t loadu(void const *mem)
-    {
-        return _mm512_loadu_si512(mem);
-    }
-    static zmm_t max(zmm_t x, zmm_t y)
-    {
-        return _mm512_mask_mov_epi16(y, ge(x, y), x);
-    }
-    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
-    {
-        // AVX512_VBMI2
-        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
-    }
-    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
-    {
-        // AVX512BW
-        return _mm512_mask_loadu_epi16(x, mask, mem);
-    }
-    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
-    {
-        return _mm512_mask_mov_epi16(x, mask, y);
-    }
-    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
-    {
-        return _mm512_mask_storeu_epi16(mem, mask, x);
-    }
-    static zmm_t min(zmm_t x, zmm_t y)
-    {
-        return _mm512_mask_mov_epi16(x, ge(x, y), y);
-    }
-    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
-    {
-        return _mm512_permutexvar_epi16(idx, zmm);
-    }
-    // Apparently this is terrible for perf, npy_half_to_float seems to work
-    // better
-    //static float uint16_to_float(uint16_t val)
-    //{
-    //    // Ideally use _mm_loadu_si16, but it's only in gcc > 11.x
-    //    // TODO: use inline ASM? https://godbolt.org/z/aGYvh7fMM
-    //    __m128i xmm = _mm_maskz_loadu_epi16(0x01, &val);
-    //    __m128 xmm2 = _mm_cvtph_ps(xmm);
-    //    return _mm_cvtss_f32(xmm2);
-    //}
-    static type_t float_to_uint16(float val)
-    {
-        __m128 xmm = _mm_load_ss(&val);
-        __m128i xmm2 = _mm_cvtps_ph(xmm, _MM_FROUND_NO_EXC);
-        return _mm_extract_epi16(xmm2, 0);
-    }
-    static type_t reducemax(zmm_t v)
-    {
-        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
-        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
-        float lo_max = _mm512_reduce_max_ps(lo);
-        float hi_max = _mm512_reduce_max_ps(hi);
-        return float_to_uint16(std::max(lo_max, hi_max));
-    }
-    static type_t reducemin(zmm_t v)
-    {
-        __m512 lo = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 0));
-        __m512 hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(v, 1));
-        float lo_min = _mm512_reduce_min_ps(lo);
-        float hi_min = _mm512_reduce_min_ps(hi);
-        return float_to_uint16(std::min(lo_min, hi_min));
-    }
-    static zmm_t set1(type_t v)
-    {
-        return _mm512_set1_epi16(v);
-    }
-    template <uint8_t mask>
-    static zmm_t shuffle(zmm_t zmm)
-    {
-        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
-        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
-    }
-    static void storeu(void *mem, zmm_t x)
-    {
-        return _mm512_storeu_si512(mem, x);
-    }
-};
-
-template <>
-struct zmm_vector<int16_t> {
-    using type_t = int16_t;
-    using zmm_t = __m512i;
-    using ymm_t = __m256i;
-    using opmask_t = __mmask32;
-    static const uint8_t numlanes = 32;
-
-    static zmm_t get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
-    static type_t type_max()
-    {
-        return X86_SIMD_SORT_MAX_INT16;
-    }
-    static type_t type_min()
-    {
-        return X86_SIMD_SORT_MIN_INT16;
-    }
-    static zmm_t zmm_max()
-    {
-        return _mm512_set1_epi16(type_max());
-    }
-    static opmask_t knot_opmask(opmask_t x)
-    {
-        return _knot_mask32(x);
-    }
-
-    static opmask_t ge(zmm_t x, zmm_t y)
-    {
-        return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
-    }
-    static zmm_t loadu(void const *mem)
-    {
-        return _mm512_loadu_si512(mem);
-    }
-    static zmm_t max(zmm_t x, zmm_t y)
-    {
-        return _mm512_max_epi16(x, y);
-    }
-    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
-    {
-        // AVX512_VBMI2
-        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
-    }
-    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
-    {
-        // AVX512BW
-        return _mm512_mask_loadu_epi16(x, mask, mem);
-    }
-    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
-    {
-        return _mm512_mask_mov_epi16(x, mask, y);
-    }
-    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
-    {
-        return _mm512_mask_storeu_epi16(mem, mask, x);
-    }
-    static zmm_t min(zmm_t x, zmm_t y)
-    {
-        return _mm512_min_epi16(x, y);
-    }
-    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
-    {
-        return _mm512_permutexvar_epi16(idx, zmm);
-    }
-    static type_t reducemax(zmm_t v)
-    {
-        zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0));
-        zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1));
-        type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo);
-        type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi);
-        return std::max(lo_max, hi_max);
-    }
-    static type_t reducemin(zmm_t v)
-    {
-        zmm_t lo = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 0));
-        zmm_t hi = _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(v, 1));
-        type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo);
-        type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi);
-        return std::min(lo_min, hi_min);
-    }
-    static zmm_t set1(type_t v)
-    {
-        return _mm512_set1_epi16(v);
-    }
-    template <uint8_t mask>
-    static zmm_t shuffle(zmm_t zmm)
-    {
-        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
-        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
-    }
-    static void storeu(void *mem, zmm_t x)
-    {
-        return _mm512_storeu_si512(mem, x);
-    }
-};
-template <>
-struct zmm_vector<uint16_t> {
-    using type_t = uint16_t;
-    using zmm_t = __m512i;
-    using ymm_t = __m256i;
-    using opmask_t = __mmask32;
-    static const uint8_t numlanes = 32;
-
-    static zmm_t get_network(int index)
-    {
-        return _mm512_loadu_si512(&network[index - 1][0]);
-    }
-    static type_t type_max()
-    {
-        return X86_SIMD_SORT_MAX_UINT16;
-    }
-    static type_t type_min()
-    {
-        return 0;
-    }
-    static zmm_t zmm_max()
-    {
-        return _mm512_set1_epi16(type_max());
-    }
-
-    static opmask_t knot_opmask(opmask_t x)
-    {
-        return _knot_mask32(x);
-    }
-    static opmask_t ge(zmm_t x, zmm_t y)
-    {
-        return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
-    }
-    static zmm_t loadu(void const *mem)
-    {
-        return _mm512_loadu_si512(mem);
-    }
-    static zmm_t max(zmm_t x, zmm_t y)
-    {
-        return _mm512_max_epu16(x, y);
-    }
-    static void mask_compressstoreu(void *mem, opmask_t mask, zmm_t x)
-    {
-        return _mm512_mask_compressstoreu_epi16(mem, mask, x);
-    }
-    static zmm_t mask_loadu(zmm_t x, opmask_t mask, void const *mem)
-    {
-        return _mm512_mask_loadu_epi16(x, mask, mem);
-    }
-    static zmm_t mask_mov(zmm_t x, opmask_t mask, zmm_t y)
-    {
-        return _mm512_mask_mov_epi16(x, mask, y);
-    }
-    static void mask_storeu(void *mem, opmask_t mask, zmm_t x)
-    {
-        return _mm512_mask_storeu_epi16(mem, mask, x);
-    }
-    static zmm_t min(zmm_t x, zmm_t y)
-    {
-        return _mm512_min_epu16(x, y);
-    }
-    static zmm_t permutexvar(__m512i idx, zmm_t zmm)
-    {
-        return _mm512_permutexvar_epi16(idx, zmm);
-    }
-    static type_t reducemax(zmm_t v)
-    {
-        zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0));
-        zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1));
-        type_t lo_max = (type_t)_mm512_reduce_max_epi32(lo);
-        type_t hi_max = (type_t)_mm512_reduce_max_epi32(hi);
-        return std::max(lo_max, hi_max);
-    }
-    static type_t reducemin(zmm_t v)
-    {
-        zmm_t lo = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 0));
-        zmm_t hi = _mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(v, 1));
-        type_t lo_min = (type_t)_mm512_reduce_min_epi32(lo);
-        type_t hi_min = (type_t)_mm512_reduce_min_epi32(hi);
-        return std::min(lo_min, hi_min);
-    }
-    static zmm_t set1(type_t v)
-    {
-        return _mm512_set1_epi16(v);
-    }
-    template <uint8_t mask>
-    static zmm_t shuffle(zmm_t zmm)
-    {
-        zmm = _mm512_shufflehi_epi16(zmm, (_MM_PERM_ENUM)mask);
-        return _mm512_shufflelo_epi16(zmm, (_MM_PERM_ENUM)mask);
-    }
-    static void storeu(void *mem, zmm_t x)
-    {
-        return _mm512_storeu_si512(mem, x);
-    }
-};
-
 template <>
 bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
 {
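The one non-trivial piece of the removed code is zmm_vector<float16>::ge(), which orders FP16 values directly on their raw bit patterns: each lane is split into sign (0x8000), exponent (0x7c00) and mantissa (0x03ff) fields, the fields are compared in order of significance for lanes whose higher-order fields are equal, and the result is flipped for lanes where both operands are negative. Below is a minimal scalar sketch of that ordering for reference; it is an illustration only, not part of this commit, the helper name fp16_ge is made up here, and, like the SIMD code, it gives no special treatment to NaN or to -0.0 versus +0.0.

#include <cstdint>

// Scalar mirror of the removed zmm_vector<float16>::ge() bit-field ordering.
static bool fp16_ge(uint16_t x, uint16_t y)
{
    uint16_t sign_x = x & 0x8000, sign_y = y & 0x8000;
    uint16_t exp_x = x & 0x7c00, exp_y = y & 0x7c00;
    uint16_t mant_x = x & 0x03ff, mant_y = y & 0x03ff;

    bool ge = sign_x < sign_y;               // x non-negative, y negative
    bool sign_eq = sign_x == sign_y;
    bool neg = sign_eq && sign_x == 0x8000;  // both negative: order reverses
    ge = ge || (sign_eq && exp_x > exp_y);   // same sign: compare exponents
    bool exp_eq = sign_eq && exp_x == exp_y;
    ge = ge || (exp_eq && mant_x >= mant_y); // same exponent: compare mantissas
    return ge != neg;                        // XOR, as _kxor_mask32 does lane-wise
}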