Skip to content

Commit 990ae6c

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #152 from sterrettm2/swizzle_cleanup
Cleanup for single vector sort/bitonic merge (and minor cleanup for argsort/argselect)
2 parents d62f656 + 3b41715 commit 990ae6c

13 files changed

+597
-722
lines changed

src/avx2-32bit-half.hpp

Lines changed: 18 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -9,36 +9,6 @@
99

1010
#include "avx2-emu-funcs.hpp"
1111

12-
/*
13-
* Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
14-
* sorting network (see
15-
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
16-
*/
17-
18-
// ymm 7, 6, 5, 4, 3, 2, 1, 0
19-
#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
20-
#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
21-
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
22-
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4
23-
24-
/*
25-
* Assumes ymm is random and performs a full sorting network defined in
26-
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
27-
*/
28-
template <typename vtype, typename reg_t = typename vtype::reg_t>
29-
X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit_half(reg_t ymm)
30-
{
31-
using swizzle = typename vtype::swizzle_ops;
32-
33-
const typename vtype::opmask_t oxAA = vtype::seti(-1, 0, -1, 0);
34-
const typename vtype::opmask_t oxCC = vtype::seti(-1, -1, 0, 0);
35-
36-
ymm = cmp_merge<vtype>(ymm, swizzle::template swap_n<vtype, 2>(ymm), oxAA);
37-
ymm = cmp_merge<vtype>(ymm, vtype::reverse(ymm), oxCC);
38-
ymm = cmp_merge<vtype>(ymm, swizzle::template swap_n<vtype, 2>(ymm), oxAA);
39-
return ymm;
40-
}
41-
4212
struct avx2_32bit_half_swizzle_ops;
4313

4414
template <>
@@ -74,6 +44,10 @@ struct avx2_half_vector<int32_t> {
7444
auto mask = ((0x1ull << num_to_read) - 0x1ull);
7545
return convert_int_to_avx2_mask_half(mask);
7646
}
47+
static opmask_t convert_int_to_mask(uint64_t intMask)
48+
{
49+
return convert_int_to_avx2_mask_half(intMask);
50+
}
7751
static regi_t seti(int v1, int v2, int v3, int v4)
7852
{
7953
return _mm_set_epi32(v1, v2, v3, v4);
@@ -155,7 +129,7 @@ struct avx2_half_vector<int32_t> {
155129
}
156130
static reg_t reverse(reg_t ymm)
157131
{
158-
const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
132+
const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
159133
return permutexvar(rev_index, ymm);
160134
}
161135
static type_t reducemax(reg_t v)
@@ -181,7 +155,7 @@ struct avx2_half_vector<int32_t> {
181155
}
182156
static reg_t sort_vec(reg_t x)
183157
{
184-
return sort_ymm_32bit_half<avx2_half_vector<type_t>>(x);
158+
return sort_reg_4lanes<avx2_half_vector<type_t>>(x);
185159
}
186160
static reg_t cast_from(__m128i v)
187161
{
@@ -237,6 +211,10 @@ struct avx2_half_vector<uint32_t> {
237211
auto mask = ((0x1ull << num_to_read) - 0x1ull);
238212
return convert_int_to_avx2_mask_half(mask);
239213
}
214+
static opmask_t convert_int_to_mask(uint64_t intMask)
215+
{
216+
return convert_int_to_avx2_mask_half(intMask);
217+
}
240218
static regi_t seti(int v1, int v2, int v3, int v4)
241219
{
242220
return _mm_set_epi32(v1, v2, v3, v4);
@@ -309,7 +287,7 @@ struct avx2_half_vector<uint32_t> {
309287
}
310288
static reg_t reverse(reg_t ymm)
311289
{
312-
const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
290+
const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
313291
return permutexvar(rev_index, ymm);
314292
}
315293
static type_t reducemax(reg_t v)
@@ -335,7 +313,7 @@ struct avx2_half_vector<uint32_t> {
335313
}
336314
static reg_t sort_vec(reg_t x)
337315
{
338-
return sort_ymm_32bit_half<avx2_half_vector<type_t>>(x);
316+
return sort_reg_4lanes<avx2_half_vector<type_t>>(x);
339317
}
340318
static reg_t cast_from(__m128i v)
341319
{
@@ -411,6 +389,10 @@ struct avx2_half_vector<float> {
411389
auto mask = ((0x1ull << num_to_read) - 0x1ull);
412390
return convert_int_to_avx2_mask_half(mask);
413391
}
392+
static opmask_t convert_int_to_mask(uint64_t intMask)
393+
{
394+
return convert_int_to_avx2_mask_half(intMask);
395+
}
414396
static int32_t convert_mask_to_int(opmask_t mask)
415397
{
416398
return convert_avx2_mask_to_int_half(mask);
@@ -478,7 +460,7 @@ struct avx2_half_vector<float> {
478460
}
479461
static reg_t reverse(reg_t ymm)
480462
{
481-
const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
463+
const __m128i rev_index = _mm_set_epi32(NETWORK_REVERSE_4LANES);
482464
return permutexvar(rev_index, ymm);
483465
}
484466
static type_t reducemax(reg_t v)
@@ -504,7 +486,7 @@ struct avx2_half_vector<float> {
504486
}
505487
static reg_t sort_vec(reg_t x)
506488
{
507-
return sort_ymm_32bit_half<avx2_half_vector<type_t>>(x);
489+
return sort_reg_4lanes<avx2_half_vector<type_t>>(x);
508490
}
509491
static reg_t cast_from(__m128i v)
510492
{

src/avx2-32bit-qsort.hpp

Lines changed: 6 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -9,51 +9,6 @@
99

1010
#include "avx2-emu-funcs.hpp"
1111

12-
/*
13-
* Constants used in sorting 8 elements in a ymm registers. Based on Bitonic
14-
* sorting network (see
15-
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
16-
*/
17-
18-
// ymm 7, 6, 5, 4, 3, 2, 1, 0
19-
#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
20-
#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
21-
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
22-
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4
23-
24-
/*
25-
* Assumes ymm is random and performs a full sorting network defined in
26-
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
27-
*/
28-
template <typename vtype, typename reg_t = typename vtype::reg_t>
29-
X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
30-
{
31-
const typename vtype::opmask_t oxAA = _mm256_set_epi32(
32-
0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
33-
const typename vtype::opmask_t oxCC = _mm256_set_epi32(
34-
0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
35-
const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
36-
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);
37-
38-
const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2);
39-
ymm = cmp_merge<vtype>(
40-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
41-
ymm = cmp_merge<vtype>(
42-
ymm,
43-
vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm),
44-
oxCC);
45-
ymm = cmp_merge<vtype>(
46-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
47-
ymm = cmp_merge<vtype>(ymm, vtype::permutexvar(rev_index, ymm), oxF0);
48-
ymm = cmp_merge<vtype>(
49-
ymm,
50-
vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm),
51-
oxCC);
52-
ymm = cmp_merge<vtype>(
53-
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
54-
return ymm;
55-
}
56-
5712
struct avx2_32bit_swizzle_ops;
5813

5914
template <>
@@ -180,7 +135,7 @@ struct avx2_vector<int32_t> {
180135
}
181136
static reg_t reverse(reg_t ymm)
182137
{
183-
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
138+
const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
184139
return permutexvar(rev_index, ymm);
185140
}
186141
static type_t reducemax(reg_t v)
@@ -206,7 +161,7 @@ struct avx2_vector<int32_t> {
206161
}
207162
static reg_t sort_vec(reg_t x)
208163
{
209-
return sort_ymm_32bit<avx2_vector<type_t>>(x);
164+
return sort_reg_8lanes<avx2_vector<type_t>>(x);
210165
}
211166
static reg_t cast_from(__m256i v)
212167
{
@@ -342,7 +297,7 @@ struct avx2_vector<uint32_t> {
342297
}
343298
static reg_t reverse(reg_t ymm)
344299
{
345-
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
300+
const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
346301
return permutexvar(rev_index, ymm);
347302
}
348303
static type_t reducemax(reg_t v)
@@ -368,7 +323,7 @@ struct avx2_vector<uint32_t> {
368323
}
369324
static reg_t sort_vec(reg_t x)
370325
{
371-
return sort_ymm_32bit<avx2_vector<type_t>>(x);
326+
return sort_reg_8lanes<avx2_vector<type_t>>(x);
372327
}
373328
static reg_t cast_from(__m256i v)
374329
{
@@ -520,7 +475,7 @@ struct avx2_vector<float> {
520475
}
521476
static reg_t reverse(reg_t ymm)
522477
{
523-
const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2);
478+
const __m256i rev_index = _mm256_set_epi32(NETWORK_REVERSE_8LANES);
524479
return permutexvar(rev_index, ymm);
525480
}
526481
static type_t reducemax(reg_t v)
@@ -547,7 +502,7 @@ struct avx2_vector<float> {
547502
}
548503
static reg_t sort_vec(reg_t x)
549504
{
550-
return sort_ymm_32bit<avx2_vector<type_t>>(x);
505+
return sort_reg_8lanes<avx2_vector<type_t>>(x);
551506
}
552507
static reg_t cast_from(__m256i v)
553508
{

src/avx2-64bit-qsort.hpp

Lines changed: 15 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -10,32 +10,6 @@
1010

1111
#include "avx2-emu-funcs.hpp"
1212

13-
/*
14-
* Assumes ymm is random and performs a full sorting network defined in
15-
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
16-
*/
17-
template <typename vtype, typename reg_t = typename vtype::reg_t>
18-
X86_SIMD_SORT_INLINE reg_t sort_ymm_64bit(reg_t ymm)
19-
{
20-
const typename vtype::opmask_t oxAA
21-
= _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0, 0xFFFFFFFFFFFFFFFF, 0);
22-
const typename vtype::opmask_t oxCC
23-
= _mm256_set_epi64x(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0);
24-
ymm = cmp_merge<vtype>(
25-
ymm,
26-
vtype::template permutexvar<SHUFFLE_MASK(2, 3, 0, 1)>(ymm),
27-
oxAA);
28-
ymm = cmp_merge<vtype>(
29-
ymm,
30-
vtype::template permutexvar<SHUFFLE_MASK(0, 1, 2, 3)>(ymm),
31-
oxCC);
32-
ymm = cmp_merge<vtype>(
33-
ymm,
34-
vtype::template permutexvar<SHUFFLE_MASK(2, 3, 0, 1)>(ymm),
35-
oxAA);
36-
return ymm;
37-
}
38-
3913
struct avx2_64bit_swizzle_ops;
4014

4115
template <>
@@ -81,6 +55,10 @@ struct avx2_vector<int64_t> {
8155
auto mask = ((0x1ull << num_to_read) - 0x1ull);
8256
return convert_int_to_avx2_mask_64bit(mask);
8357
}
58+
static opmask_t convert_int_to_mask(uint64_t intMask)
59+
{
60+
return convert_int_to_avx2_mask_64bit(intMask);
61+
}
8462
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
8563
{
8664
return _mm256_set_epi64x(v1, v2, v3, v4);
@@ -207,7 +185,7 @@ struct avx2_vector<int64_t> {
207185
}
208186
static reg_t sort_vec(reg_t x)
209187
{
210-
return sort_ymm_64bit<avx2_vector<type_t>>(x);
188+
return sort_reg_4lanes<avx2_vector<type_t>>(x);
211189
}
212190
static reg_t cast_from(__m256i v)
213191
{
@@ -265,6 +243,10 @@ struct avx2_vector<uint64_t> {
265243
auto mask = ((0x1ull << num_to_read) - 0x1ull);
266244
return convert_int_to_avx2_mask_64bit(mask);
267245
}
246+
static opmask_t convert_int_to_mask(uint64_t intMask)
247+
{
248+
return convert_int_to_avx2_mask_64bit(intMask);
249+
}
268250
static ymmi_t seti(int64_t v1, int64_t v2, int64_t v3, int64_t v4)
269251
{
270252
return _mm256_set_epi64x(v1, v2, v3, v4);
@@ -389,7 +371,7 @@ struct avx2_vector<uint64_t> {
389371
}
390372
static reg_t sort_vec(reg_t x)
391373
{
392-
return sort_ymm_64bit<avx2_vector<type_t>>(x);
374+
return sort_reg_4lanes<avx2_vector<type_t>>(x);
393375
}
394376
static reg_t cast_from(__m256i v)
395377
{
@@ -460,6 +442,10 @@ struct avx2_vector<double> {
460442
auto mask = ((0x1ull << num_to_read) - 0x1ull);
461443
return convert_int_to_avx2_mask_64bit(mask);
462444
}
445+
static opmask_t convert_int_to_mask(uint64_t intMask)
446+
{
447+
return convert_int_to_avx2_mask_64bit(intMask);
448+
}
463449
static int32_t convert_mask_to_int(opmask_t mask)
464450
{
465451
return convert_avx2_mask_to_int_64bit(mask);
@@ -593,7 +579,7 @@ struct avx2_vector<double> {
593579
}
594580
static reg_t sort_vec(reg_t x)
595581
{
596-
return sort_ymm_64bit<avx2_vector<type_t>>(x);
582+
return sort_reg_4lanes<avx2_vector<type_t>>(x);
597583
}
598584
static reg_t cast_from(__m256i v)
599585
{

0 commit comments

Comments
 (0)