7
7
#ifndef AVX2_HALF_32BIT
8
8
#define AVX2_HALF_32BIT
9
9
10
- #include " xss-common-qsort .h"
10
+ #include " xss-common-includes .h"
11
11
#include " avx2-emu-funcs.hpp"
12
12
13
13
/*
@@ -46,7 +46,7 @@ template <>
46
46
struct avx2_half_vector <int32_t > {
47
47
using type_t = int32_t ;
48
48
using reg_t = __m128i;
49
- using ymmi_t = __m128i;
49
+ using regi_t = __m128i;
50
50
using opmask_t = __m128i;
51
51
static const uint8_t numlanes = 4 ;
52
52
static constexpr simd_type vec_type = simd_type::AVX2;
@@ -70,7 +70,7 @@ struct avx2_half_vector<int32_t> {
70
70
auto mask = ((0x1ull << num_to_read) - 0x1ull );
71
71
return convert_int_to_avx2_mask_half (mask);
72
72
}
73
- static ymmi_t seti (int v1, int v2, int v3, int v4)
73
+ static regi_t seti (int v1, int v2, int v3, int v4)
74
74
{
75
75
return _mm_set_epi32 (v1, v2, v3, v4);
76
76
}
@@ -86,8 +86,7 @@ struct avx2_half_vector<int32_t> {
86
86
{
87
87
opmask_t equal = eq (x, y);
88
88
opmask_t greater = _mm_cmpgt_epi32 (x, y);
89
- return _mm_castps_si128 (
90
- _mm_or_ps (_mm_castsi128_ps (equal), _mm_castsi128_ps (greater)));
89
+ return _mm_or_si128 (equal, greater);
91
90
}
92
91
static opmask_t eq (reg_t x, reg_t y)
93
92
{
@@ -150,10 +149,6 @@ struct avx2_half_vector<int32_t> {
150
149
{
151
150
return _mm_castps_si128 (_mm_permutevar_ps (_mm_castsi128_ps (ymm), idx));
152
151
}
153
- static reg_t permutevar (reg_t ymm, __m128i idx)
154
- {
155
- return _mm_castps_si128 (_mm_permutevar_ps (_mm_castsi128_ps (ymm), idx));
156
- }
157
152
static reg_t reverse (reg_t ymm)
158
153
{
159
154
const __m128i rev_index = _mm_set_epi32 (0 , 1 , 2 , 3 );
@@ -205,7 +200,7 @@ template <>
205
200
struct avx2_half_vector <uint32_t > {
206
201
using type_t = uint32_t ;
207
202
using reg_t = __m128i;
208
- using ymmi_t = __m128i;
203
+ using regi_t = __m128i;
209
204
using opmask_t = __m128i;
210
205
static const uint8_t numlanes = 4 ;
211
206
static constexpr simd_type vec_type = simd_type::AVX2;
@@ -229,7 +224,7 @@ struct avx2_half_vector<uint32_t> {
229
224
auto mask = ((0x1ull << num_to_read) - 0x1ull );
230
225
return convert_int_to_avx2_mask_half (mask);
231
226
}
232
- static ymmi_t seti (int v1, int v2, int v3, int v4)
227
+ static regi_t seti (int v1, int v2, int v3, int v4)
233
228
{
234
229
return _mm_set_epi32 (v1, v2, v3, v4);
235
230
}
@@ -299,10 +294,6 @@ struct avx2_half_vector<uint32_t> {
299
294
{
300
295
return _mm_castps_si128 (_mm_permutevar_ps (_mm_castsi128_ps (ymm), idx));
301
296
}
302
- static reg_t permutevar (reg_t ymm, __m128i idx)
303
- {
304
- return _mm_castps_si128 (_mm_permutevar_ps (_mm_castsi128_ps (ymm), idx));
305
- }
306
297
static reg_t reverse (reg_t ymm)
307
298
{
308
299
const __m128i rev_index = _mm_set_epi32 (0 , 1 , 2 , 3 );
@@ -354,7 +345,7 @@ template <>
354
345
struct avx2_half_vector <float > {
355
346
using type_t = float ;
356
347
using reg_t = __m128;
357
- using ymmi_t = __m128i;
348
+ using regi_t = __m128i;
358
349
using opmask_t = __m128i;
359
350
static const uint8_t numlanes = 4 ;
360
351
static constexpr simd_type vec_type = simd_type::AVX2;
@@ -374,7 +365,7 @@ struct avx2_half_vector<float> {
374
365
return _mm_set1_ps (type_max ());
375
366
}
376
367
377
- static ymmi_t seti (int v1, int v2, int v3, int v4)
368
+ static regi_t seti (int v1, int v2, int v3, int v4)
378
369
{
379
370
return _mm_set_epi32 (v1, v2, v3, v4);
380
371
}
@@ -464,10 +455,6 @@ struct avx2_half_vector<float> {
464
455
{
465
456
return _mm_permutevar_ps (ymm, idx);
466
457
}
467
- static reg_t permutevar (reg_t ymm, __m128i idx)
468
- {
469
- return _mm_permutevar_ps (ymm, idx);
470
- }
471
458
static reg_t reverse (reg_t ymm)
472
459
{
473
460
const __m128i rev_index = _mm_set_epi32 (0 , 1 , 2 , 3 );
@@ -520,23 +507,15 @@ struct avx2_32bit_half_swizzle_ops {
520
507
template <typename vtype, int scale>
521
508
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n (typename vtype::reg_t reg)
522
509
{
523
- __m128i v = vtype::cast_to (reg);
524
-
525
510
if constexpr (scale == 2 ) {
526
- __m128 vf = _mm_castsi128_ps (v);
527
- vf = _mm_permute_ps (vf, 0b10110001 );
528
- v = _mm_castps_si128 (vf);
511
+ return vtype::template shuffle<0b10110001 >(reg);
529
512
}
530
513
else if constexpr (scale == 4 ) {
531
- __m128 vf = _mm_castsi128_ps (v);
532
- vf = _mm_permute_ps (vf, 0b01001110 );
533
- v = _mm_castps_si128 (vf);
514
+ return vtype::template shuffle<0b01001110 >(reg);
534
515
}
535
516
else {
536
517
static_assert (scale == -1 , " should not be reached" );
537
518
}
538
-
539
- return vtype::cast_from (v);
540
519
}
541
520
542
521
template <typename vtype, int scale>
0 commit comments