Skip to content

Commit 3561db3

Browse files
committed
Changed partition code
1 parent 2f397e4 commit 3561db3

7 files changed

+106
-71
lines changed

src/avx2-32bit-common.h

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,6 @@ struct avx2_vector<int32_t> {
129129
{
130130
return avx2_emu_mask_compressstoreu<type_t>(mem, mask, x);
131131
}
132-
static int32_t double_compressstore(type_t *left_addr,
133-
type_t *right_addr,
134-
opmask_t k,
135-
reg_t reg)
136-
{
137-
return avx2_double_compressstore32<type_t>(
138-
left_addr, right_addr, k, reg);
139-
}
140132
static reg_t maskz_loadu(opmask_t mask, void const *mem)
141133
{
142134
return _mm256_maskload_epi32((const int *)mem, mask);
@@ -210,6 +202,13 @@ struct avx2_vector<int32_t> {
210202
static __m256i cast_to(reg_t v){
211203
return v;
212204
}
205+
static int double_compressstore(type_t *left_addr,
206+
type_t *right_addr,
207+
opmask_t k,
208+
reg_t reg)
209+
{
210+
return avx2_double_compressstore32<type_t>(left_addr, right_addr, k, reg);
211+
}
213212
};
214213
template <>
215214
struct avx2_vector<uint32_t> {
@@ -277,14 +276,6 @@ struct avx2_vector<uint32_t> {
277276
{
278277
return avx2_emu_mask_compressstoreu<type_t>(mem, mask, x);
279278
}
280-
static int32_t double_compressstore(type_t *left_addr,
281-
type_t *right_addr,
282-
opmask_t k,
283-
reg_t reg)
284-
{
285-
return avx2_double_compressstore32<type_t>(
286-
left_addr, right_addr, k, reg);
287-
}
288279
static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem)
289280
{
290281
reg_t dst = _mm256_maskload_epi32((const int *)mem, mask);
@@ -353,6 +344,13 @@ struct avx2_vector<uint32_t> {
353344
static __m256i cast_to(reg_t v){
354345
return v;
355346
}
347+
static int double_compressstore(type_t *left_addr,
348+
type_t *right_addr,
349+
opmask_t k,
350+
reg_t reg)
351+
{
352+
return avx2_double_compressstore32<type_t>(left_addr, right_addr, k, reg);
353+
}
356354
};
357355
template <>
358356
struct avx2_vector<float> {
@@ -439,14 +437,6 @@ struct avx2_vector<float> {
439437
{
440438
return avx2_emu_mask_compressstoreu<type_t>(mem, mask, x);
441439
}
442-
static int32_t double_compressstore(type_t *left_addr,
443-
type_t *right_addr,
444-
opmask_t k,
445-
reg_t reg)
446-
{
447-
return avx2_double_compressstore32<type_t>(
448-
left_addr, right_addr, k, reg);
449-
}
450440
static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem)
451441
{
452442
reg_t dst = _mm256_maskload_ps((type_t *)mem, mask);
@@ -517,6 +507,13 @@ struct avx2_vector<float> {
517507
static __m256i cast_to(reg_t v){
518508
return _mm256_castps_si256(v);
519509
}
510+
static int double_compressstore(type_t *left_addr,
511+
type_t *right_addr,
512+
opmask_t k,
513+
reg_t reg)
514+
{
515+
return avx2_double_compressstore32<type_t>(left_addr, right_addr, k, reg);
516+
}
520517
};
521518

522519
struct avx2_32bit_swizzle_ops{

src/avx2-emu-funcs.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ void avx2_emu_mask_compressstoreu(void *base_addr,
140140
}
141141

142142
template <typename T>
143-
int32_t avx2_double_compressstore32(void *left_addr,
143+
int avx2_double_compressstore32(void *left_addr,
144144
void *right_addr,
145145
typename avx2_vector<T>::opmask_t k,
146146
typename avx2_vector<T>::reg_t reg)

src/avx512-16bit-qsort.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,13 @@ struct zmm_vector<float16> {
177177
{
178178
return v;
179179
}
180+
static int double_compressstore(type_t *left_addr,
181+
type_t *right_addr,
182+
opmask_t k,
183+
reg_t reg)
184+
{
185+
return avx512_double_compressstore<zmm_vector<float16>>(left_addr, right_addr, k, reg);
186+
}
180187
};
181188

182189
template <>
@@ -301,6 +308,13 @@ struct zmm_vector<int16_t> {
301308
{
302309
return v;
303310
}
311+
static int double_compressstore(type_t *left_addr,
312+
type_t *right_addr,
313+
opmask_t k,
314+
reg_t reg)
315+
{
316+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
317+
}
304318
};
305319
template <>
306320
struct zmm_vector<uint16_t> {
@@ -422,6 +436,13 @@ struct zmm_vector<uint16_t> {
422436
{
423437
return v;
424438
}
439+
static int double_compressstore(type_t *left_addr,
440+
type_t *right_addr,
441+
opmask_t k,
442+
reg_t reg)
443+
{
444+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
445+
}
425446
};
426447

427448
template <>

src/avx512-32bit-qsort.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ struct zmm_vector<int32_t> {
154154
{
155155
return v;
156156
}
157+
static int double_compressstore(type_t *left_addr,
158+
type_t *right_addr,
159+
opmask_t k,
160+
reg_t reg)
161+
{
162+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
163+
}
157164
};
158165
template <>
159166
struct zmm_vector<uint32_t> {
@@ -281,6 +288,13 @@ struct zmm_vector<uint32_t> {
281288
{
282289
return v;
283290
}
291+
static int double_compressstore(type_t *left_addr,
292+
type_t *right_addr,
293+
opmask_t k,
294+
reg_t reg)
295+
{
296+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
297+
}
284298
};
285299
template <>
286300
struct zmm_vector<float> {
@@ -422,6 +436,13 @@ struct zmm_vector<float> {
422436
{
423437
return _mm512_castps_si512(v);
424438
}
439+
static int double_compressstore(type_t *left_addr,
440+
type_t *right_addr,
441+
opmask_t k,
442+
reg_t reg)
443+
{
444+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
445+
}
425446
};
426447

427448
/*

src/avx512-64bit-common.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,13 @@ struct zmm_vector<int64_t> {
660660
{
661661
return v;
662662
}
663+
static int double_compressstore(type_t *left_addr,
664+
type_t *right_addr,
665+
opmask_t k,
666+
reg_t reg)
667+
{
668+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
669+
}
663670
};
664671
template <>
665672
struct zmm_vector<uint64_t> {
@@ -818,6 +825,13 @@ struct zmm_vector<uint64_t> {
818825
{
819826
return v;
820827
}
828+
static int double_compressstore(type_t *left_addr,
829+
type_t *right_addr,
830+
opmask_t k,
831+
reg_t reg)
832+
{
833+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
834+
}
821835
};
822836
template <>
823837
struct zmm_vector<double> {
@@ -982,6 +996,13 @@ struct zmm_vector<double> {
982996
{
983997
return _mm512_castpd_si512(v);
984998
}
999+
static int double_compressstore(type_t *left_addr,
1000+
type_t *right_addr,
1001+
opmask_t k,
1002+
reg_t reg)
1003+
{
1004+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
1005+
}
9851006
};
9861007

9871008
/*

src/avx512fp16-16bit-qsort.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,13 @@ struct zmm_vector<_Float16> {
145145
{
146146
return _mm512_castph_si512(v);
147147
}
148+
static int double_compressstore(type_t *left_addr,
149+
type_t *right_addr,
150+
opmask_t k,
151+
reg_t reg)
152+
{
153+
return avx512_double_compressstore<zmm_vector<type_t>>(left_addr, right_addr, k, reg);
154+
}
148155
};
149156

150157
template <>

src/xss-common-qsort.h

Lines changed: 14 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -162,73 +162,41 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask)
162162
reg_t max = vtype::max(in2, in1);
163163
return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max
164164
}
165-
/*
166-
* Partitions one ZMM register based on the pivot and returns the
167-
* number of elements that are greater than or equal to the pivot.
168-
*/
165+
169166
template <typename vtype, typename type_t, typename reg_t>
170-
X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx512(type_t *l_store,
171-
type_t *r_store,
172-
const reg_t curr_vec,
173-
const reg_t pivot_vec,
174-
reg_t &smallest_vec,
175-
reg_t &biggest_vec)
167+
int avx512_double_compressstore(type_t *left_addr,
168+
type_t *right_addr,
169+
typename vtype::opmask_t k,
170+
reg_t reg)
176171
{
177-
typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec);
178-
int amount_ge_pivot = _mm_popcnt_u32((int)ge_mask);
172+
int amount_ge_pivot = _mm_popcnt_u32((int)k);
179173

180-
vtype::mask_compressstoreu(l_store, vtype::knot_opmask(ge_mask), curr_vec);
174+
vtype::mask_compressstoreu(left_addr, vtype::knot_opmask(k), reg);
181175
vtype::mask_compressstoreu(
182-
r_store + vtype::numlanes - amount_ge_pivot, ge_mask, curr_vec);
183-
184-
smallest_vec = vtype::min(curr_vec, smallest_vec);
185-
biggest_vec = vtype::max(curr_vec, biggest_vec);
186-
176+
right_addr + vtype::numlanes - amount_ge_pivot, k, reg);
177+
187178
return amount_ge_pivot;
188179
}
189-
/*
190-
* Partitions one YMM register based on the pivot and returns the
191-
* number of elements that are greater than or equal to the pivot.
192-
*/
180+
181+
// Generic function dispatches to AVX2 or AVX512 code
193182
template <typename vtype, typename type_t, typename reg_t = typename vtype::reg_t>
194-
X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx2(type_t *l_store,
183+
X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store,
195184
type_t *r_store,
196185
const reg_t curr_vec,
197186
const reg_t pivot_vec,
198187
reg_t &smallest_vec,
199188
reg_t &biggest_vec)
200189
{
201-
/* which elements are larger than or equal to the pivot */
202190
typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec);
203-
204-
int32_t amount_ge_pivot = vtype::double_compressstore(
205-
l_store, r_store, ge_mask, curr_vec);
191+
192+
int amount_ge_pivot = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec);
206193

207194
smallest_vec = vtype::min(curr_vec, smallest_vec);
208195
biggest_vec = vtype::max(curr_vec, biggest_vec);
209196

210197
return amount_ge_pivot;
211198
}
212199

213-
// Generic function dispatches to AVX2 or AVX512 code
214-
template <typename vtype, typename type_t, typename reg_t = typename vtype::reg_t>
215-
X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store,
216-
type_t *r_store,
217-
const reg_t curr_vec,
218-
const reg_t pivot_vec,
219-
reg_t &smallest_vec,
220-
reg_t &biggest_vec)
221-
{
222-
if constexpr (sizeof(reg_t) == 64){
223-
return partition_vec_avx512<vtype>(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec);
224-
}else if constexpr (sizeof(reg_t) == 32){
225-
return partition_vec_avx2<vtype>(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec);
226-
}else{
227-
static_assert(sizeof(reg_t) == -1, "should not reach here");
228-
return 0;
229-
}
230-
}
231-
232200
/*
233201
* Partitions an array based on the pivot and returns the index of the
234202
* first element that is greater than or equal to the pivot.

0 commit comments

Comments
 (0)