Skip to content

Commit 4d9b15b

Browse files
author
Raghuveer Devulapalli
authored
Merge pull request #65 from r-devulap/gather-scalar
Use scalar emulation of gather instruction for arg methods
2 parents 73cfb4f + 22c2f02 commit 4d9b15b

File tree

5 files changed

+427
-341
lines changed

5 files changed

+427
-341
lines changed

src/avx512-64bit-argsort.hpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ X86_SIMD_SORT_INLINE void argsort_8_64bit(type_t *arr, int64_t *arg, int32_t N)
6767
{
6868
using reg_t = typename vtype::reg_t;
6969
typename vtype::opmask_t load_mask = (0x01 << N) - 0x01;
70-
argzmm_t argzmm = argtype::maskz_loadu(load_mask, arg);
70+
argreg_t argzmm = argtype::maskz_loadu(load_mask, arg);
7171
reg_t arrzmm = vtype::template mask_i64gather<sizeof(type_t)>(
7272
vtype::zmm_max(), load_mask, argzmm, arr);
7373
arrzmm = sort_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
@@ -83,9 +83,9 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N)
8383
}
8484
using reg_t = typename vtype::reg_t;
8585
typename vtype::opmask_t load_mask = (0x01 << (N - 8)) - 0x01;
86-
argzmm_t argzmm1 = argtype::loadu(arg);
87-
argzmm_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
88-
reg_t arrzmm1 = vtype::template i64gather<sizeof(type_t)>(argzmm1, arr);
86+
argreg_t argzmm1 = argtype::loadu(arg);
87+
argreg_t argzmm2 = argtype::maskz_loadu(load_mask, arg + 8);
88+
reg_t arrzmm1 = vtype::i64gather(arr, arg);
8989
reg_t arrzmm2 = vtype::template mask_i64gather<sizeof(type_t)>(
9090
vtype::zmm_max(), load_mask, argzmm2, arr);
9191
arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
@@ -106,12 +106,12 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N)
106106
using reg_t = typename vtype::reg_t;
107107
using opmask_t = typename vtype::opmask_t;
108108
reg_t arrzmm[4];
109-
argzmm_t argzmm[4];
109+
argreg_t argzmm[4];
110110

111111
X86_SIMD_SORT_UNROLL_LOOP(2)
112112
for (int ii = 0; ii < 2; ++ii) {
113113
argzmm[ii] = argtype::loadu(arg + 8 * ii);
114-
arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
114+
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
115115
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
116116
}
117117

@@ -149,12 +149,12 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
149149
using reg_t = typename vtype::reg_t;
150150
using opmask_t = typename vtype::opmask_t;
151151
reg_t arrzmm[8];
152-
argzmm_t argzmm[8];
152+
argreg_t argzmm[8];
153153

154154
X86_SIMD_SORT_UNROLL_LOOP(4)
155155
for (int ii = 0; ii < 4; ++ii) {
156156
argzmm[ii] = argtype::loadu(arg + 8 * ii);
157-
arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
157+
arrzmm[ii] = vtype::i64gather(arr, arg + 8 * ii);
158158
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
159159
}
160160

@@ -201,12 +201,12 @@ X86_SIMD_SORT_UNROLL_LOOP(4)
201201
// using reg_t = typename vtype::reg_t;
202202
// using opmask_t = typename vtype::opmask_t;
203203
// reg_t arrzmm[16];
204-
// argzmm_t argzmm[16];
204+
// argreg_t argzmm[16];
205205
//
206206
//X86_SIMD_SORT_UNROLL_LOOP(8)
207207
// for (int ii = 0; ii < 8; ++ii) {
208208
// argzmm[ii] = argtype::loadu(arg + 8*ii);
209-
// arrzmm[ii] = vtype::template i64gather<sizeof(type_t)>(argzmm[ii], arr);
209+
// arrzmm[ii] = vtype::i64gather(arr, arg + 8*ii);
210210
// arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
211211
// }
212212
//
@@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr,
257257
// median of 8
258258
int64_t size = (right - left) / 8;
259259
using reg_t = typename vtype::reg_t;
260-
// TODO: Use gather here too:
261-
__m512i rand_index = _mm512_set_epi64(arg[left + size],
262-
arg[left + 2 * size],
263-
arg[left + 3 * size],
264-
arg[left + 4 * size],
265-
arg[left + 5 * size],
266-
arg[left + 6 * size],
267-
arg[left + 7 * size],
268-
arg[left + 8 * size]);
269-
reg_t rand_vec
270-
= vtype::template i64gather<sizeof(type_t)>(rand_index, arr);
260+
reg_t rand_vec = vtype::set(arr[arg[left + size]],
261+
arr[arg[left + 2 * size]],
262+
arr[arg[left + 3 * size]],
263+
arr[arg[left + 4 * size]],
264+
arr[arg[left + 5 * size]],
265+
arr[arg[left + 6 * size]],
266+
arr[arg[left + 7 * size]],
267+
arr[arg[left + 8 * size]]);
271268
// pivot will never be a nan, since there are no nan's!
272269
reg_t sort = sort_zmm_64bit<vtype>(rand_vec);
273270
return ((type_t *)&sort)[4];

0 commit comments

Comments
 (0)