15
15
* sorting network (see
16
16
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg)
17
17
*/
18
-
18
+
19
19
// ymm 7, 6, 5, 4, 3, 2, 1, 0
20
20
#define NETWORK_32BIT_AVX2_1 4 , 5 , 6 , 7 , 0 , 1 , 2 , 3
21
21
#define NETWORK_32BIT_AVX2_2 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7
@@ -58,11 +58,11 @@ struct avx2_vector<int32_t> {
58
58
using type_t = int32_t ;
59
59
using reg_t = __m256i;
60
60
using ymmi_t = __m256i;
61
- using opmask_t = avx2_mask_helper32 ;
61
+ using opmask_t = __m256i ;
62
62
static const uint8_t numlanes = 8 ;
63
63
static constexpr int network_sort_threshold = 256 ;
64
64
static constexpr int partition_unroll_factor = 4 ;
65
-
65
+
66
66
using swizzle_ops = avx2_32bit_swizzle_ops;
67
67
68
68
static type_t type_max ()
@@ -77,7 +77,11 @@ struct avx2_vector<int32_t> {
77
77
{
78
78
return _mm256_set1_epi32 (type_max ());
79
79
} // TODO: this should broadcast bits as is?
80
-
80
+ static opmask_t get_partial_loadmask (uint64_t num_to_read)
81
+ {
82
+ auto mask = ((0x1ull << num_to_read) - 0x1ull );
83
+ return convert_int_to_avx2_mask (mask);
84
+ }
81
85
static ymmi_t
82
86
seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
83
87
{
@@ -215,11 +219,11 @@ struct avx2_vector<uint32_t> {
215
219
using type_t = uint32_t ;
216
220
using reg_t = __m256i;
217
221
using ymmi_t = __m256i;
218
- using opmask_t = avx2_mask_helper32 ;
222
+ using opmask_t = __m256i ;
219
223
static const uint8_t numlanes = 8 ;
220
224
static constexpr int network_sort_threshold = 256 ;
221
225
static constexpr int partition_unroll_factor = 4 ;
222
-
226
+
223
227
using swizzle_ops = avx2_32bit_swizzle_ops;
224
228
225
229
static type_t type_max ()
@@ -234,7 +238,11 @@ struct avx2_vector<uint32_t> {
234
238
{
235
239
return _mm256_set1_epi32 (type_max ());
236
240
}
237
-
241
+ static opmask_t get_partial_loadmask (uint64_t num_to_read)
242
+ {
243
+ auto mask = ((0x1ull << num_to_read) - 0x1ull );
244
+ return convert_int_to_avx2_mask (mask);
245
+ }
238
246
static ymmi_t
239
247
seti (int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8)
240
248
{
@@ -357,11 +365,11 @@ struct avx2_vector<float> {
357
365
using type_t = float ;
358
366
using reg_t = __m256;
359
367
using ymmi_t = __m256i;
360
- using opmask_t = avx2_mask_helper32 ;
368
+ using opmask_t = __m256i ;
361
369
static const uint8_t numlanes = 8 ;
362
370
static constexpr int network_sort_threshold = 256 ;
363
371
static constexpr int partition_unroll_factor = 4 ;
364
-
372
+
365
373
using swizzle_ops = avx2_32bit_swizzle_ops;
366
374
367
375
static type_t type_max ()
@@ -399,9 +407,14 @@ struct avx2_vector<float> {
399
407
{
400
408
return _mm256_castps_si256 (_mm256_cmp_ps (x, y, _CMP_EQ_OQ));
401
409
}
402
- static opmask_t get_partial_loadmask (int size)
410
+ static opmask_t get_partial_loadmask (uint64_t num_to_read)
411
+ {
412
+ auto mask = ((0x1ull << num_to_read) - 0x1ull );
413
+ return convert_int_to_avx2_mask (mask);
414
+ }
415
+ static int32_t convert_mask_to_int (opmask_t mask)
403
416
{
404
- return ( 0x0001 << size) - 0x0001 ;
417
+ return convert_avx2_mask_to_int (mask) ;
405
418
}
406
419
template <int type>
407
420
static opmask_t fpclass (reg_t x)
0 commit comments