Skip to content

Commit 9ad4432

Browse files
author
Raghuveer Devulapalli
committed
Get rid of the avx2_mask_helper
1 parent 3561db3 commit 9ad4432

8 files changed

+106
-70
lines changed

src/avx2-32bit-common.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* sorting network (see
1616
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
1717
*/
18-
18+
1919
// ymm 7, 6, 5, 4, 3, 2, 1, 0
2020
#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3
2121
#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7
@@ -58,11 +58,11 @@ struct avx2_vector<int32_t> {
5858
using type_t = int32_t;
5959
using reg_t = __m256i;
6060
using ymmi_t = __m256i;
61-
using opmask_t = avx2_mask_helper32;
61+
using opmask_t = __m256i;
6262
static const uint8_t numlanes = 8;
6363
static constexpr int network_sort_threshold = 256;
6464
static constexpr int partition_unroll_factor = 4;
65-
65+
6666
using swizzle_ops = avx2_32bit_swizzle_ops;
6767

6868
static type_t type_max()
@@ -77,7 +77,11 @@ struct avx2_vector<int32_t> {
7777
{
7878
return _mm256_set1_epi32(type_max());
7979
} // TODO: this should broadcast bits as is?
80-
80+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
81+
{
82+
auto mask = ((0x1ull << num_to_read) - 0x1ull);
83+
return convert_int_to_avx2_mask(mask);
84+
}
8185
static ymmi_t
8286
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
8387
{
@@ -215,11 +219,11 @@ struct avx2_vector<uint32_t> {
215219
using type_t = uint32_t;
216220
using reg_t = __m256i;
217221
using ymmi_t = __m256i;
218-
using opmask_t = avx2_mask_helper32;
222+
using opmask_t = __m256i;
219223
static const uint8_t numlanes = 8;
220224
static constexpr int network_sort_threshold = 256;
221225
static constexpr int partition_unroll_factor = 4;
222-
226+
223227
using swizzle_ops = avx2_32bit_swizzle_ops;
224228

225229
static type_t type_max()
@@ -234,7 +238,11 @@ struct avx2_vector<uint32_t> {
234238
{
235239
return _mm256_set1_epi32(type_max());
236240
}
237-
241+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
242+
{
243+
auto mask = ((0x1ull << num_to_read) - 0x1ull);
244+
return convert_int_to_avx2_mask(mask);
245+
}
238246
static ymmi_t
239247
seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
240248
{
@@ -357,11 +365,11 @@ struct avx2_vector<float> {
357365
using type_t = float;
358366
using reg_t = __m256;
359367
using ymmi_t = __m256i;
360-
using opmask_t = avx2_mask_helper32;
368+
using opmask_t = __m256i;
361369
static const uint8_t numlanes = 8;
362370
static constexpr int network_sort_threshold = 256;
363371
static constexpr int partition_unroll_factor = 4;
364-
372+
365373
using swizzle_ops = avx2_32bit_swizzle_ops;
366374

367375
static type_t type_max()
@@ -399,9 +407,14 @@ struct avx2_vector<float> {
399407
{
400408
return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ));
401409
}
402-
static opmask_t get_partial_loadmask(int size)
410+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
411+
{
412+
auto mask = ((0x1ull << num_to_read) - 0x1ull);
413+
return convert_int_to_avx2_mask(mask);
414+
}
415+
static int32_t convert_mask_to_int(opmask_t mask)
403416
{
404-
return (0x0001 << size) - 0x0001;
417+
return convert_avx2_mask_to_int(mask);
405418
}
406419
template <int type>
407420
static opmask_t fpclass(reg_t x)

src/avx2-emu-funcs.hpp

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -46,50 +46,21 @@ constexpr auto avx2_compressstore_lut32_gen = [] {
4646
}
4747
return lutPair;
4848
}();
49+
4950
constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0];
5051
constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1];
5152

52-
struct avx2_mask_helper32 {
53-
__m256i mask;
54-
55-
avx2_mask_helper32() = default;
56-
avx2_mask_helper32(int m)
57-
{
58-
mask = converter(m);
59-
}
60-
avx2_mask_helper32(__m256i m)
61-
{
62-
mask = m;
63-
}
64-
operator __m256i()
65-
{
66-
return mask;
67-
}
68-
operator int32_t()
69-
{
70-
return converter(mask);
71-
}
72-
__m256i operator=(int m)
73-
{
74-
mask = converter(m);
75-
return mask;
76-
}
77-
78-
private:
79-
__m256i converter(int m)
80-
{
81-
return _mm256_loadu_si256(
82-
(const __m256i *)avx2_mask_helper_lut32[m].data());
83-
}
53+
X86_SIMD_SORT_INLINE
54+
__m256i convert_int_to_avx2_mask(int32_t m)
55+
{
56+
return _mm256_loadu_si256(
57+
(const __m256i *)avx2_mask_helper_lut32[m].data());
58+
}
8459

85-
int32_t converter(__m256i m)
86-
{
87-
return _mm256_movemask_ps(_mm256_castsi256_ps(m));
88-
}
89-
};
90-
static __m256i operator~(const avx2_mask_helper32 x)
60+
X86_SIMD_SORT_INLINE
61+
int32_t convert_avx2_mask_to_int(__m256i m)
9162
{
92-
return ~x.mask;
63+
return _mm256_movemask_ps(_mm256_castsi256_ps(m));
9364
}
9465

9566
// Emulators for intrinsics missing from AVX2 compared to AVX512
@@ -98,7 +69,7 @@ T avx2_emu_reduce_max32(typename avx2_vector<T>::reg_t x)
9869
{
9970
using vtype = avx2_vector<T>;
10071
using reg_t = typename vtype::reg_t;
101-
72+
10273
reg_t inter1 = vtype::max(x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
10374
reg_t inter2 = vtype::max(inter1, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(inter1));
10475
T can1 = vtype::template extract<0>(inter2);
@@ -111,7 +82,7 @@ T avx2_emu_reduce_min32(typename avx2_vector<T>::reg_t x)
11182
{
11283
using vtype = avx2_vector<T>;
11384
using reg_t = typename vtype::reg_t;
114-
85+
11586
reg_t inter1 = vtype::min(x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
11687
reg_t inter2 = vtype::min(inter1, vtype::template shuffle<SHUFFLE_MASK(1, 0, 3, 2)>(inter1));
11788
T can1 = vtype::template extract<0>(inter2);
@@ -128,7 +99,7 @@ void avx2_emu_mask_compressstoreu(void *base_addr,
12899

129100
T *leftStore = (T *)base_addr;
130101

131-
int32_t shortMask = avx2_mask_helper32(k);
102+
int32_t shortMask = convert_avx2_mask_to_int(k);
132103
const __m256i &perm = _mm256_loadu_si256(
133104
(const __m256i *)avx2_compressstore_lut32_perm[shortMask].data());
134105
const __m256i &left = _mm256_loadu_si256(
@@ -150,7 +121,7 @@ int avx2_double_compressstore32(void *left_addr,
150121
T *leftStore = (T *)left_addr;
151122
T *rightStore = (T *)right_addr;
152123

153-
int32_t shortMask = avx2_mask_helper32(k);
124+
int32_t shortMask = convert_avx2_mask_to_int(k);
154125
const __m256i &perm = _mm256_loadu_si256(
155126
(const __m256i *)avx2_compressstore_lut32_perm[shortMask].data());
156127
const __m256i &left = _mm256_loadu_si256(
@@ -186,4 +157,4 @@ typename avx2_vector<T>::reg_t avx2_emu_min(typename avx2_vector<T>::reg_t x,
186157
_mm256_castsi256_pd(nlt)));
187158
}
188159

189-
#endif
160+
#endif

src/avx512-16bit-qsort.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ struct zmm_vector<float16> {
8080
exp_eq, mant_x, mant_y, _MM_CMPINT_NLT);
8181
return _kxor_mask32(mask_ge, neg);
8282
}
83+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
84+
{
85+
return ((0x1ull << num_to_read) - 0x1ull);
86+
}
87+
static int32_t convert_mask_to_int(opmask_t mask)
88+
{
89+
return mask;
90+
}
8391
static reg_t loadu(void const *mem)
8492
{
8593
return _mm512_loadu_si512(mem);
@@ -227,6 +235,10 @@ struct zmm_vector<int16_t> {
227235
{
228236
return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT);
229237
}
238+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
239+
{
240+
return ((0x1ull << num_to_read) - 0x1ull);
241+
}
230242
static reg_t loadu(void const *mem)
231243
{
232244
return _mm512_loadu_si512(mem);
@@ -357,6 +369,10 @@ struct zmm_vector<uint16_t> {
357369
{
358370
return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT);
359371
}
372+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
373+
{
374+
return ((0x1ull << num_to_read) - 0x1ull);
375+
}
360376
static reg_t loadu(void const *mem)
361377
{
362378
return _mm512_loadu_si512(mem);

src/avx512-32bit-qsort.hpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ struct zmm_vector<int32_t> {
6565
{
6666
return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT);
6767
}
68+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
69+
{
70+
return ((0x1ull << num_to_read) - 0x1ull);
71+
}
6872
template <int scale>
6973
static halfreg_t i64gather(__m512i index, void const *base)
7074
{
@@ -209,6 +213,10 @@ struct zmm_vector<uint32_t> {
209213
{
210214
return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT);
211215
}
216+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
217+
{
218+
return ((0x1ull << num_to_read) - 0x1ull);
219+
}
212220
static reg_t loadu(void const *mem)
213221
{
214222
return _mm512_loadu_si512(mem);
@@ -333,9 +341,13 @@ struct zmm_vector<float> {
333341
{
334342
return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ);
335343
}
336-
static opmask_t get_partial_loadmask(int size)
344+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
345+
{
346+
return ((0x1ull << num_to_read) - 0x1ull);
347+
}
348+
static int32_t convert_mask_to_int(opmask_t mask)
337349
{
338-
return (0x0001 << size) - 0x0001;
350+
return mask;
339351
}
340352
template <int type>
341353
static opmask_t fpclass(reg_t x)

src/avx512-64bit-common.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ struct ymm_vector<float> {
8181
{
8282
return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ);
8383
}
84-
static opmask_t get_partial_loadmask(int size)
84+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
8585
{
86-
return (0x01 << size) - 0x01;
86+
return ((0x1ull << num_to_read) - 0x1ull);
8787
}
8888
template <int type>
8989
static opmask_t fpclass(reg_t x)
@@ -244,6 +244,10 @@ struct ymm_vector<uint32_t> {
244244
{
245245
return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_NLT);
246246
}
247+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
248+
{
249+
return ((0x1ull << num_to_read) - 0x1ull);
250+
}
247251
static opmask_t eq(reg_t x, reg_t y)
248252
{
249253
return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_EQ);
@@ -396,6 +400,10 @@ struct ymm_vector<int32_t> {
396400
{
397401
return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_NLT);
398402
}
403+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
404+
{
405+
return ((0x1ull << num_to_read) - 0x1ull);
406+
}
399407
static opmask_t eq(reg_t x, reg_t y)
400408
{
401409
return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_EQ);
@@ -557,6 +565,10 @@ struct zmm_vector<int64_t> {
557565
{
558566
return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT);
559567
}
568+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
569+
{
570+
return ((0x1ull << num_to_read) - 0x1ull);
571+
}
560572
static opmask_t eq(reg_t x, reg_t y)
561573
{
562574
return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ);
@@ -745,6 +757,10 @@ struct zmm_vector<uint64_t> {
745757
{
746758
return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT);
747759
}
760+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
761+
{
762+
return ((0x1ull << num_to_read) - 0x1ull);
763+
}
748764
static opmask_t eq(reg_t x, reg_t y)
749765
{
750766
return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ);
@@ -894,9 +910,13 @@ struct zmm_vector<double> {
894910
{
895911
return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ);
896912
}
897-
static opmask_t get_partial_loadmask(int size)
913+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
914+
{
915+
return ((0x1ull << num_to_read) - 0x1ull);
916+
}
917+
static int32_t convert_mask_to_int(opmask_t mask)
898918
{
899-
return (0x01 << size) - 0x01;
919+
return mask;
900920
}
901921
template <int type>
902922
static opmask_t fpclass(reg_t x)

src/avx512fp16-16bit-qsort.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,13 @@ struct zmm_vector<_Float16> {
5454
{
5555
return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ);
5656
}
57-
static opmask_t get_partial_loadmask(int size)
57+
static opmask_t get_partial_loadmask(uint64_t num_to_read)
5858
{
59-
return (0x00000001 << size) - 0x00000001;
59+
return ((0x1ull << num_to_read) - 0x1ull);
60+
}
61+
static int32_t convert_mask_to_int(opmask_t mask)
62+
{
63+
return mask;
6064
}
6165
template <int type>
6266
static opmask_t fpclass(reg_t x)

src/xss-common-qsort.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size)
6565
in = vtype::loadu(arr + ii);
6666
}
6767
opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in);
68-
nan_count += _mm_popcnt_u32((int32_t)nanmask);
68+
nan_count += _mm_popcnt_u32(vtype::convert_mask_to_int(nanmask));
6969
vtype::mask_storeu(arr + ii, nanmask, vtype::zmm_max());
7070
}
7171
return nan_count;
@@ -174,7 +174,7 @@ int avx512_double_compressstore(type_t *left_addr,
174174
vtype::mask_compressstoreu(left_addr, vtype::knot_opmask(k), reg);
175175
vtype::mask_compressstoreu(
176176
right_addr + vtype::numlanes - amount_ge_pivot, k, reg);
177-
177+
178178
return amount_ge_pivot;
179179
}
180180

@@ -188,7 +188,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store,
188188
reg_t &biggest_vec)
189189
{
190190
typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec);
191-
191+
192192
int amount_ge_pivot = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec);
193193

194194
smallest_vec = vtype::min(curr_vec, smallest_vec);

0 commit comments

Comments
 (0)