Skip to content

Commit 1834edc

Browse files
committed
Code review fixes
1 parent 7388ed7 commit 1834edc

File tree

4 files changed

+20
-70
lines changed

4 files changed

+20
-70
lines changed

src/avx2-32bit-half.hpp

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#ifndef AVX2_HALF_32BIT
88
#define AVX2_HALF_32BIT
99

10-
#include "xss-common-qsort.h"
10+
#include "xss-common-includes.h"
1111
#include "avx2-emu-funcs.hpp"
1212

1313
/*
@@ -46,7 +46,7 @@ template <>
4646
struct avx2_half_vector<int32_t> {
4747
using type_t = int32_t;
4848
using reg_t = __m128i;
49-
using ymmi_t = __m128i;
49+
using regi_t = __m128i;
5050
using opmask_t = __m128i;
5151
static const uint8_t numlanes = 4;
5252
static constexpr simd_type vec_type = simd_type::AVX2;
@@ -70,7 +70,7 @@ struct avx2_half_vector<int32_t> {
7070
auto mask = ((0x1ull << num_to_read) - 0x1ull);
7171
return convert_int_to_avx2_mask_half(mask);
7272
}
73-
static ymmi_t seti(int v1, int v2, int v3, int v4)
73+
static regi_t seti(int v1, int v2, int v3, int v4)
7474
{
7575
return _mm_set_epi32(v1, v2, v3, v4);
7676
}
@@ -86,8 +86,7 @@ struct avx2_half_vector<int32_t> {
8686
{
8787
opmask_t equal = eq(x, y);
8888
opmask_t greater = _mm_cmpgt_epi32(x, y);
89-
return _mm_castps_si128(
90-
_mm_or_ps(_mm_castsi128_ps(equal), _mm_castsi128_ps(greater)));
89+
return _mm_or_si128(equal, greater);
9190
}
9291
static opmask_t eq(reg_t x, reg_t y)
9392
{
@@ -150,10 +149,6 @@ struct avx2_half_vector<int32_t> {
150149
{
151150
return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx));
152151
}
153-
static reg_t permutevar(reg_t ymm, __m128i idx)
154-
{
155-
return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx));
156-
}
157152
static reg_t reverse(reg_t ymm)
158153
{
159154
const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
@@ -205,7 +200,7 @@ template <>
205200
struct avx2_half_vector<uint32_t> {
206201
using type_t = uint32_t;
207202
using reg_t = __m128i;
208-
using ymmi_t = __m128i;
203+
using regi_t = __m128i;
209204
using opmask_t = __m128i;
210205
static const uint8_t numlanes = 4;
211206
static constexpr simd_type vec_type = simd_type::AVX2;
@@ -229,7 +224,7 @@ struct avx2_half_vector<uint32_t> {
229224
auto mask = ((0x1ull << num_to_read) - 0x1ull);
230225
return convert_int_to_avx2_mask_half(mask);
231226
}
232-
static ymmi_t seti(int v1, int v2, int v3, int v4)
227+
static regi_t seti(int v1, int v2, int v3, int v4)
233228
{
234229
return _mm_set_epi32(v1, v2, v3, v4);
235230
}
@@ -299,10 +294,6 @@ struct avx2_half_vector<uint32_t> {
299294
{
300295
return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx));
301296
}
302-
static reg_t permutevar(reg_t ymm, __m128i idx)
303-
{
304-
return _mm_castps_si128(_mm_permutevar_ps(_mm_castsi128_ps(ymm), idx));
305-
}
306297
static reg_t reverse(reg_t ymm)
307298
{
308299
const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
@@ -354,7 +345,7 @@ template <>
354345
struct avx2_half_vector<float> {
355346
using type_t = float;
356347
using reg_t = __m128;
357-
using ymmi_t = __m128i;
348+
using regi_t = __m128i;
358349
using opmask_t = __m128i;
359350
static const uint8_t numlanes = 4;
360351
static constexpr simd_type vec_type = simd_type::AVX2;
@@ -374,7 +365,7 @@ struct avx2_half_vector<float> {
374365
return _mm_set1_ps(type_max());
375366
}
376367

377-
static ymmi_t seti(int v1, int v2, int v3, int v4)
368+
static regi_t seti(int v1, int v2, int v3, int v4)
378369
{
379370
return _mm_set_epi32(v1, v2, v3, v4);
380371
}
@@ -464,10 +455,6 @@ struct avx2_half_vector<float> {
464455
{
465456
return _mm_permutevar_ps(ymm, idx);
466457
}
467-
static reg_t permutevar(reg_t ymm, __m128i idx)
468-
{
469-
return _mm_permutevar_ps(ymm, idx);
470-
}
471458
static reg_t reverse(reg_t ymm)
472459
{
473460
const __m128i rev_index = _mm_set_epi32(0, 1, 2, 3);
@@ -520,23 +507,15 @@ struct avx2_32bit_half_swizzle_ops {
520507
template <typename vtype, int scale>
521508
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg)
522509
{
523-
__m128i v = vtype::cast_to(reg);
524-
525510
if constexpr (scale == 2) {
526-
__m128 vf = _mm_castsi128_ps(v);
527-
vf = _mm_permute_ps(vf, 0b10110001);
528-
v = _mm_castps_si128(vf);
511+
return vtype::template shuffle<0b10110001>(reg);
529512
}
530513
else if constexpr (scale == 4) {
531-
__m128 vf = _mm_castsi128_ps(v);
532-
vf = _mm_permute_ps(vf, 0b01001110);
533-
v = _mm_castps_si128(vf);
514+
return vtype::template shuffle<0b01001110>(reg);
534515
}
535516
else {
536517
static_assert(scale == -1, "should not be reached");
537518
}
538-
539-
return vtype::cast_from(v);
540519
}
541520

542521
template <typename vtype, int scale>

src/avx2-emu-funcs.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ void avx2_emu_mask_compressstoreu32(void *base_addr,
277277
const __m256i &left = _mm256_loadu_si256(
278278
(const __m256i *)avx2_compressstore_lut32_left[shortMask].data());
279279

280-
typename vtype::reg_t temp = vtype::permutevar(reg, perm);
280+
typename vtype::reg_t temp = vtype::permutexvar(perm, reg);
281281

282282
vtype::mask_storeu(leftStore, left, temp);
283283
}
@@ -300,7 +300,7 @@ void avx2_emu_mask_compressstoreu32_half(
300300
(const __m128i *)avx2_compressstore_lut32_half_left[shortMask]
301301
.data());
302302

303-
typename vtype::reg_t temp = vtype::permutevar(reg, perm);
303+
typename vtype::reg_t temp = vtype::permutexvar(perm, reg);
304304

305305
vtype::mask_storeu(leftStore, left, temp);
306306
}
@@ -341,7 +341,7 @@ int avx2_double_compressstore32(void *left_addr,
341341
const __m256i &perm = _mm256_loadu_si256(
342342
(const __m256i *)avx2_compressstore_lut32_perm[shortMask].data());
343343

344-
typename vtype::reg_t temp = vtype::permutevar(reg, perm);
344+
typename vtype::reg_t temp = vtype::permutexvar(perm, reg);
345345

346346
vtype::storeu(leftStore, temp);
347347
vtype::storeu(rightStore, temp);
@@ -365,7 +365,7 @@ int avx2_double_compressstore32_half(void *left_addr,
365365
(const __m128i *)avx2_compressstore_lut32_half_perm[shortMask]
366366
.data());
367367

368-
typename vtype::reg_t temp = vtype::permutevar(reg, perm);
368+
typename vtype::reg_t temp = vtype::permutexvar(perm, reg);
369369

370370
vtype::storeu(leftStore, temp);
371371
vtype::storeu(rightStore, temp);

src/xss-common-argsort.h

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,6 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right)
6464
});
6565
}
6666

67-
/* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of
68-
* undefined template 'zmm_vector<unsigned long>'*/
69-
#ifdef __APPLE__
70-
using argtypeAVX512 =
71-
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
72-
ymm_vector<uint32_t>,
73-
zmm_vector<uint64_t>>::type;
74-
#else
75-
using argtypeAVX512 =
76-
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
77-
ymm_vector<arrsize_t>,
78-
zmm_vector<arrsize_t>>::type;
79-
#endif
80-
8167
/*
8268
* Parition one ZMM register based on the pivot and returns the index of the
8369
* last element that is less than equal to the pivot.
@@ -129,7 +115,7 @@ X86_SIMD_SORT_INLINE int32_t partition_vec_avx2(type_t *arg,
129115
/* which elements are larger than the pivot */
130116
typename vtype::opmask_t ge_mask_vtype = vtype::ge(curr_vec, pivot_vec);
131117
typename argtype::opmask_t ge_mask
132-
= extend_mask<vtype, argtype>(ge_mask_vtype);
118+
= resize_mask<vtype, argtype>(ge_mask_vtype);
133119

134120
auto l_store = arg + left;
135121
auto r_store = arg + right - vtype::numlanes;
@@ -727,19 +713,4 @@ avx2_argselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
727713
return indices;
728714
}
729715

730-
/* To maintain compatibility with NumPy build */
731-
template <typename T>
732-
X86_SIMD_SORT_INLINE void
733-
avx512_argselect(T *arr, int64_t *arg, arrsize_t k, arrsize_t arrsize)
734-
{
735-
avx512_argselect(arr, reinterpret_cast<arrsize_t *>(arg), k, arrsize);
736-
}
737-
738-
template <typename T>
739-
X86_SIMD_SORT_INLINE void
740-
avx512_argsort(T *arr, int64_t *arg, arrsize_t arrsize)
741-
{
742-
avx512_argsort(arr, reinterpret_cast<arrsize_t *>(arg), arrsize);
743-
}
744-
745716
#endif // XSS_COMMON_ARGSORT

src/xss-network-keyvaluesort.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#define NETWORK_32BIT_7 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8
1111

1212
template <typename keyType, typename valueType>
13-
typename valueType::opmask_t extend_mask(typename keyType::opmask_t mask)
13+
typename valueType::opmask_t resize_mask(typename keyType::opmask_t mask)
1414
{
1515
using inT = typename keyType::opmask_t;
1616
using outT = typename valueType::opmask_t;
@@ -48,7 +48,7 @@ COEX(reg_t1 &key1, reg_t1 &key2, reg_t2 &index1, reg_t2 &index2)
4848
reg_t1 key_t1 = vtype1::min(key1, key2);
4949
reg_t1 key_t2 = vtype1::max(key1, key2);
5050

51-
auto eqMask = extend_mask<vtype1, vtype2>(vtype1::eq(key_t1, key1));
51+
auto eqMask = resize_mask<vtype1, vtype2>(vtype1::eq(key_t1, key1));
5252

5353
reg_t2 index_t1 = vtype2::mask_mov(index2, eqMask, index1);
5454
reg_t2 index_t2 = vtype2::mask_mov(index1, eqMask, index2);
@@ -73,7 +73,7 @@ X86_SIMD_SORT_INLINE reg_t1 cmp_merge(reg_t1 in1,
7373
reg_t1 tmp_keys = cmp_merge<vtype1>(in1, in2, mask);
7474
indexes1 = vtype2::mask_mov(
7575
indexes2,
76-
extend_mask<vtype1, vtype2>(vtype1::eq(tmp_keys, in1)),
76+
resize_mask<vtype1, vtype2>(vtype1::eq(tmp_keys, in1)),
7777
indexes1);
7878
return tmp_keys; // 0 -> min, 1 -> max
7979
}
@@ -503,7 +503,7 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
503503
for (int i = numVecs / 2; i < numVecs; i++) {
504504
indexVecs[i] = indexType::mask_loadu(
505505
indexType::zmm_max(),
506-
extend_mask<keyType, indexType>(ioMasks[i - numVecs / 2]),
506+
resize_mask<keyType, indexType>(ioMasks[i - numVecs / 2]),
507507
indices + i * indexType::numlanes);
508508

509509
keyVecs[i] = keyType::template mask_i64gather<sizeof(
@@ -532,7 +532,7 @@ X86_SIMD_SORT_INLINE void argsort_n_vec(typename keyType::type_t *keys,
532532
for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
533533
indexType::mask_storeu(
534534
indices + i * indexType::numlanes,
535-
extend_mask<keyType, indexType>(ioMasks[i - numVecs / 2]),
535+
resize_mask<keyType, indexType>(ioMasks[i - numVecs / 2]),
536536
indexVecs[i]);
537537
}
538538
}

0 commit comments

Comments
 (0)