Skip to content

Commit 4efed2a

Browse files
author
Raghuveer Devulapalli
committed
Style format
1 parent bdd0af6 commit 4efed2a

8 files changed

+150
-120
lines changed

src/avx512-16bit-qsort.hpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ bool comparison_func<zmm_vector<float16>>(const uint16_t &a, const uint16_t &b)
377377
//return npy_half_to_float(a) < npy_half_to_float(b);
378378
}
379379

380-
template<>
381-
int64_t
382-
replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr, int64_t arrsize)
380+
template <>
381+
int64_t replace_nan_with_inf<zmm_vector<float16>>(uint16_t *arr,
382+
int64_t arrsize)
383383
{
384384
int64_t nan_count = 0;
385385
__mmask16 loadmask = 0xFFFF;
@@ -405,21 +405,28 @@ bool is_a_nan<uint16_t>(uint16_t elem)
405405

406406
/* Specialized template function for 16-bit qsort_ funcs*/
407407
template <>
408-
void qsort_<zmm_vector<int16_t>>(int16_t* arr, int64_t left, int64_t right, int64_t maxiters)
408+
void qsort_<zmm_vector<int16_t>>(int16_t *arr,
409+
int64_t left,
410+
int64_t right,
411+
int64_t maxiters)
409412
{
410413
qsort_16bit_<zmm_vector<int16_t>>(arr, left, right, maxiters);
411414
}
412415

413416
template <>
414-
void qsort_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t left, int64_t right, int64_t maxiters)
417+
void qsort_<zmm_vector<uint16_t>>(uint16_t *arr,
418+
int64_t left,
419+
int64_t right,
420+
int64_t maxiters)
415421
{
416422
qsort_16bit_<zmm_vector<uint16_t>>(arr, left, right, maxiters);
417423
}
418424

419425
void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
420426
{
421427
if (arrsize > 1) {
422-
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(arr, arrsize);
428+
int64_t nan_count = replace_nan_with_inf<zmm_vector<float16>, uint16_t>(
429+
arr, arrsize);
423430
qsort_16bit_<zmm_vector<float16>, uint16_t>(
424431
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
425432
replace_inf_with_nan(arr, arrsize, nan_count);
@@ -428,13 +435,15 @@ void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize)
428435

429436
/* Specialized template function for 16-bit qselect_ funcs*/
430437
template <>
431-
void qselect_<zmm_vector<int16_t>>(int16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
438+
void qselect_<zmm_vector<int16_t>>(
439+
int16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
432440
{
433441
qselect_16bit_<zmm_vector<int16_t>>(arr, k, left, right, maxiters);
434442
}
435443

436444
template <>
437-
void qselect_<zmm_vector<uint16_t>>(uint16_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
445+
void qselect_<zmm_vector<uint16_t>>(
446+
uint16_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
438447
{
439448
qselect_16bit_<zmm_vector<uint16_t>>(arr, k, left, right, maxiters);
440449
}

src/avx512-32bit-qsort.hpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -704,38 +704,50 @@ static void qselect_32bit_(type_t *arr,
704704

705705
/* Specialized template function for 32-bit qselect_ funcs*/
706706
template <>
707-
void qselect_<zmm_vector<int32_t>>(int32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
707+
void qselect_<zmm_vector<int32_t>>(
708+
int32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
708709
{
709710
qselect_32bit_<zmm_vector<int32_t>>(arr, k, left, right, maxiters);
710711
}
711712

712713
template <>
713-
void qselect_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
714+
void qselect_<zmm_vector<uint32_t>>(
715+
uint32_t *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
714716
{
715717
qselect_32bit_<zmm_vector<uint32_t>>(arr, k, left, right, maxiters);
716718
}
717719

718720
template <>
719-
void qselect_<zmm_vector<float>>(float* arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
721+
void qselect_<zmm_vector<float>>(
722+
float *arr, int64_t k, int64_t left, int64_t right, int64_t maxiters)
720723
{
721724
qselect_32bit_<zmm_vector<float>>(arr, k, left, right, maxiters);
722725
}
723726

724727
/* Specialized template function for 32-bit qsort_ funcs*/
725728
template <>
726-
void qsort_<zmm_vector<int32_t>>(int32_t* arr, int64_t left, int64_t right, int64_t maxiters)
729+
void qsort_<zmm_vector<int32_t>>(int32_t *arr,
730+
int64_t left,
731+
int64_t right,
732+
int64_t maxiters)
727733
{
728734
qsort_32bit_<zmm_vector<int32_t>>(arr, left, right, maxiters);
729735
}
730736

731737
template <>
732-
void qsort_<zmm_vector<uint32_t>>(uint32_t* arr, int64_t left, int64_t right, int64_t maxiters)
738+
void qsort_<zmm_vector<uint32_t>>(uint32_t *arr,
739+
int64_t left,
740+
int64_t right,
741+
int64_t maxiters)
733742
{
734743
qsort_32bit_<zmm_vector<uint32_t>>(arr, left, right, maxiters);
735744
}
736745

737746
template <>
738-
void qsort_<zmm_vector<float>>(float* arr, int64_t left, int64_t right, int64_t maxiters)
747+
void qsort_<zmm_vector<float>>(float *arr,
748+
int64_t left,
749+
int64_t right,
750+
int64_t maxiters)
739751
{
740752
qsort_32bit_<zmm_vector<float>>(arr, left, right, maxiters);
741753
}

src/avx512-64bit-common.h

Lines changed: 56 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,8 @@ struct ymm_vector<float> {
4040
return _mm256_set1_ps(type_max());
4141
}
4242

43-
static zmmi_t seti(int v1,
44-
int v2,
45-
int v3,
46-
int v4,
47-
int v5,
48-
int v6,
49-
int v7,
50-
int v8)
43+
static zmmi_t
44+
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
5145
{
5246
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
5347
}
@@ -93,7 +87,7 @@ struct ymm_vector<float> {
9387
}
9488
static zmm_t loadu(void const *mem)
9589
{
96-
return _mm256_loadu_ps((float*) mem);
90+
return _mm256_loadu_ps((float *)mem);
9791
}
9892
static zmm_t max(zmm_t x, zmm_t y)
9993
{
@@ -129,16 +123,22 @@ struct ymm_vector<float> {
129123
}
130124
static type_t reducemax(zmm_t v)
131125
{
132-
__m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps (v, 1));
133-
__m128 v64 = _mm_max_ps(v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2)));
134-
__m128 v32 = _mm_max_ps(v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1)));
126+
__m128 v128 = _mm_max_ps(_mm256_castps256_ps128(v),
127+
_mm256_extractf32x4_ps(v, 1));
128+
__m128 v64 = _mm_max_ps(
129+
v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2)));
130+
__m128 v32 = _mm_max_ps(
131+
v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1)));
135132
return _mm_cvtss_f32(v32);
136133
}
137134
static type_t reducemin(zmm_t v)
138135
{
139-
__m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v), _mm256_extractf32x4_ps(v, 1));
140-
__m128 v64 = _mm_min_ps(v128, _mm_shuffle_ps(v128, v128,_MM_SHUFFLE(1, 0, 3, 2)));
141-
__m128 v32 = _mm_min_ps(v64, _mm_shuffle_ps(v64, v64,_MM_SHUFFLE(0, 0, 0, 1)));
136+
__m128 v128 = _mm_min_ps(_mm256_castps256_ps128(v),
137+
_mm256_extractf32x4_ps(v, 1));
138+
__m128 v64 = _mm_min_ps(
139+
v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(1, 0, 3, 2)));
140+
__m128 v32 = _mm_min_ps(
141+
v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1)));
142142
return _mm_cvtss_f32(v32);
143143
}
144144
static zmm_t set1(type_t v)
@@ -160,7 +160,7 @@ struct ymm_vector<float> {
160160
}
161161
static void storeu(void *mem, zmm_t x)
162162
{
163-
_mm256_storeu_ps((float*)mem, x);
163+
_mm256_storeu_ps((float *)mem, x);
164164
}
165165
};
166166
template <>
@@ -184,14 +184,8 @@ struct ymm_vector<uint32_t> {
184184
return _mm256_set1_epi32(type_max());
185185
}
186186

187-
static zmmi_t seti(int v1,
188-
int v2,
189-
int v3,
190-
int v4,
191-
int v5,
192-
int v6,
193-
int v7,
194-
int v8)
187+
static zmmi_t
188+
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
195189
{
196190
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
197191
}
@@ -228,7 +222,7 @@ struct ymm_vector<uint32_t> {
228222
}
229223
static zmm_t loadu(void const *mem)
230224
{
231-
return _mm256_loadu_si256((__m256i*) mem);
225+
return _mm256_loadu_si256((__m256i *)mem);
232226
}
233227
static zmm_t max(zmm_t x, zmm_t y)
234228
{
@@ -264,16 +258,22 @@ struct ymm_vector<uint32_t> {
264258
}
265259
static type_t reducemax(zmm_t v)
266260
{
267-
__m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
268-
__m128i v64 = _mm_max_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
269-
__m128i v32 = _mm_max_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
261+
__m128i v128 = _mm_max_epu32(_mm256_castsi256_si128(v),
262+
_mm256_extracti128_si256(v, 1));
263+
__m128i v64 = _mm_max_epu32(
264+
v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
265+
__m128i v32 = _mm_max_epu32(
266+
v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
270267
return (type_t)_mm_cvtsi128_si32(v32);
271268
}
272269
static type_t reducemin(zmm_t v)
273270
{
274-
__m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
275-
__m128i v64 = _mm_min_epu32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
276-
__m128i v32 = _mm_min_epu32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
271+
__m128i v128 = _mm_min_epu32(_mm256_castsi256_si128(v),
272+
_mm256_extracti128_si256(v, 1));
273+
__m128i v64 = _mm_min_epu32(
274+
v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
275+
__m128i v32 = _mm_min_epu32(
276+
v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
277277
return (type_t)_mm_cvtsi128_si32(v32);
278278
}
279279
static zmm_t set1(type_t v)
@@ -289,7 +289,7 @@ struct ymm_vector<uint32_t> {
289289
}
290290
static void storeu(void *mem, zmm_t x)
291291
{
292-
_mm256_storeu_si256((__m256i*) mem, x);
292+
_mm256_storeu_si256((__m256i *)mem, x);
293293
}
294294
};
295295
template <>
@@ -313,14 +313,8 @@ struct ymm_vector<int32_t> {
313313
return _mm256_set1_epi32(type_max());
314314
} // TODO: this should broadcast bits as is?
315315

316-
static zmmi_t seti(int v1,
317-
int v2,
318-
int v3,
319-
int v4,
320-
int v5,
321-
int v6,
322-
int v7,
323-
int v8)
316+
static zmmi_t
317+
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
324318
{
325319
return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8);
326320
}
@@ -357,7 +351,7 @@ struct ymm_vector<int32_t> {
357351
}
358352
static zmm_t loadu(void const *mem)
359353
{
360-
return _mm256_loadu_si256((__m256i*) mem);
354+
return _mm256_loadu_si256((__m256i *)mem);
361355
}
362356
static zmm_t max(zmm_t x, zmm_t y)
363357
{
@@ -393,16 +387,22 @@ struct ymm_vector<int32_t> {
393387
}
394388
static type_t reducemax(zmm_t v)
395389
{
396-
__m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
397-
__m128i v64 = _mm_max_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
398-
__m128i v32 = _mm_max_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
390+
__m128i v128 = _mm_max_epi32(_mm256_castsi256_si128(v),
391+
_mm256_extracti128_si256(v, 1));
392+
__m128i v64 = _mm_max_epi32(
393+
v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
394+
__m128i v32 = _mm_max_epi32(
395+
v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
399396
return (type_t)_mm_cvtsi128_si32(v32);
400397
}
401398
static type_t reducemin(zmm_t v)
402399
{
403-
__m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
404-
__m128i v64 = _mm_min_epi32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
405-
__m128i v32 = _mm_min_epi32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
400+
__m128i v128 = _mm_min_epi32(_mm256_castsi256_si128(v),
401+
_mm256_extracti128_si256(v, 1));
402+
__m128i v64 = _mm_min_epi32(
403+
v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(1, 0, 3, 2)));
404+
__m128i v32 = _mm_min_epi32(
405+
v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1)));
406406
return (type_t)_mm_cvtsi128_si32(v32);
407407
}
408408
static zmm_t set1(type_t v)
@@ -418,7 +418,7 @@ struct ymm_vector<int32_t> {
418418
}
419419
static void storeu(void *mem, zmm_t x)
420420
{
421-
_mm256_storeu_si256((__m256i*) mem, x);
421+
_mm256_storeu_si256((__m256i *)mem, x);
422422
}
423423
};
424424
template <>
@@ -443,14 +443,8 @@ struct zmm_vector<int64_t> {
443443
return _mm512_set1_epi64(type_max());
444444
} // TODO: this should broadcast bits as is?
445445

446-
static zmmi_t seti(int v1,
447-
int v2,
448-
int v3,
449-
int v4,
450-
int v5,
451-
int v6,
452-
int v7,
453-
int v8)
446+
static zmmi_t
447+
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
454448
{
455449
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
456450
}
@@ -567,14 +561,8 @@ struct zmm_vector<uint64_t> {
567561
return _mm512_set1_epi64(type_max());
568562
}
569563

570-
static zmmi_t seti(int v1,
571-
int v2,
572-
int v3,
573-
int v4,
574-
int v5,
575-
int v6,
576-
int v7,
577-
int v8)
564+
static zmmi_t
565+
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
578566
{
579567
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
580568
}
@@ -679,14 +667,8 @@ struct zmm_vector<double> {
679667
return _mm512_set1_pd(type_max());
680668
}
681669

682-
static zmmi_t seti(int v1,
683-
int v2,
684-
int v3,
685-
int v4,
686-
int v5,
687-
int v6,
688-
int v7,
689-
int v8)
670+
static zmmi_t
671+
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
690672
{
691673
return _mm512_set_epi64(v1, v2, v3, v4, v5, v6, v7, v8);
692674
}
@@ -793,16 +775,12 @@ X86_SIMD_SORT_INLINE zmm_t sort_zmm_64bit(zmm_t zmm)
793775
zmm = cmp_merge<vtype>(
794776
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
795777
zmm = cmp_merge<vtype>(
796-
zmm,
797-
vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm),
798-
0xCC);
778+
zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_1), zmm), 0xCC);
799779
zmm = cmp_merge<vtype>(
800780
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
801781
zmm = cmp_merge<vtype>(zmm, vtype::permutexvar(rev_index, zmm), 0xF0);
802782
zmm = cmp_merge<vtype>(
803-
zmm,
804-
vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm),
805-
0xCC);
783+
zmm, vtype::permutexvar(vtype::seti(NETWORK_64BIT_3), zmm), 0xCC);
806784
zmm = cmp_merge<vtype>(
807785
zmm, vtype::template shuffle<SHUFFLE_MASK(1, 1, 1, 1)>(zmm), 0xAA);
808786
return zmm;

src/avx512-64bit-keyvaluesort.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,8 @@ void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize)
444444
{
445445
if (arrsize > 1) {
446446
if constexpr (std::is_floating_point_v<T1>) {
447-
int64_t nan_count = replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
447+
int64_t nan_count
448+
= replace_nan_with_inf<zmm_vector<double>>(keys, arrsize);
448449
qsort_64bit_<zmm_vector<T1>, zmm_vector<T2>>(
449450
keys, indexes, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
450451
replace_inf_with_nan(keys, arrsize, nan_count);

0 commit comments

Comments
 (0)