@@ -67,7 +67,7 @@ X86_SIMD_SORT_INLINE void argsort_8_64bit(type_t *arr, int64_t *arg, int32_t N)
67
67
{
68
68
using reg_t = typename vtype::reg_t ;
69
69
typename vtype::opmask_t load_mask = (0x01 << N) - 0x01 ;
70
- argzmm_t argzmm = argtype::maskz_loadu (load_mask, arg);
70
+ argreg_t argzmm = argtype::maskz_loadu (load_mask, arg);
71
71
reg_t arrzmm = vtype::template mask_i64gather<sizeof (type_t )>(
72
72
vtype::zmm_max (), load_mask, argzmm, arr);
73
73
arrzmm = sort_zmm_64bit<vtype, argtype>(arrzmm, argzmm);
@@ -83,9 +83,9 @@ X86_SIMD_SORT_INLINE void argsort_16_64bit(type_t *arr, int64_t *arg, int32_t N)
83
83
}
84
84
using reg_t = typename vtype::reg_t ;
85
85
typename vtype::opmask_t load_mask = (0x01 << (N - 8 )) - 0x01 ;
86
- argzmm_t argzmm1 = argtype::loadu (arg);
87
- argzmm_t argzmm2 = argtype::maskz_loadu (load_mask, arg + 8 );
88
- reg_t arrzmm1 = vtype::template i64gather< sizeof ( type_t )>(argzmm1, arr );
86
+ argreg_t argzmm1 = argtype::loadu (arg);
87
+ argreg_t argzmm2 = argtype::maskz_loadu (load_mask, arg + 8 );
88
+ reg_t arrzmm1 = vtype::i64gather (arr, arg );
89
89
reg_t arrzmm2 = vtype::template mask_i64gather<sizeof (type_t )>(
90
90
vtype::zmm_max (), load_mask, argzmm2, arr);
91
91
arrzmm1 = sort_zmm_64bit<vtype, argtype>(arrzmm1, argzmm1);
@@ -106,12 +106,12 @@ X86_SIMD_SORT_INLINE void argsort_32_64bit(type_t *arr, int64_t *arg, int32_t N)
106
106
using reg_t = typename vtype::reg_t ;
107
107
using opmask_t = typename vtype::opmask_t ;
108
108
reg_t arrzmm[4 ];
109
- argzmm_t argzmm[4 ];
109
+ argreg_t argzmm[4 ];
110
110
111
111
X86_SIMD_SORT_UNROLL_LOOP (2 )
112
112
for (int ii = 0 ; ii < 2 ; ++ii) {
113
113
argzmm[ii] = argtype::loadu (arg + 8 * ii);
114
- arrzmm[ii] = vtype::template i64gather< sizeof ( type_t )>(argzmm[ii], arr );
114
+ arrzmm[ii] = vtype::i64gather (arr, arg + 8 * ii );
115
115
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
116
116
}
117
117
@@ -149,12 +149,12 @@ X86_SIMD_SORT_INLINE void argsort_64_64bit(type_t *arr, int64_t *arg, int32_t N)
149
149
using reg_t = typename vtype::reg_t ;
150
150
using opmask_t = typename vtype::opmask_t ;
151
151
reg_t arrzmm[8 ];
152
- argzmm_t argzmm[8 ];
152
+ argreg_t argzmm[8 ];
153
153
154
154
X86_SIMD_SORT_UNROLL_LOOP (4 )
155
155
for (int ii = 0 ; ii < 4 ; ++ii) {
156
156
argzmm[ii] = argtype::loadu (arg + 8 * ii);
157
- arrzmm[ii] = vtype::template i64gather< sizeof ( type_t )>(argzmm[ii], arr );
157
+ arrzmm[ii] = vtype::i64gather (arr, arg + 8 * ii );
158
158
arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
159
159
}
160
160
@@ -201,12 +201,12 @@ X86_SIMD_SORT_UNROLL_LOOP(4)
201
201
// using reg_t = typename vtype::reg_t;
202
202
// using opmask_t = typename vtype::opmask_t;
203
203
// reg_t arrzmm[16];
204
- // argzmm_t argzmm[16];
204
+ // argreg_t argzmm[16];
205
205
//
206
206
// X86_SIMD_SORT_UNROLL_LOOP(8)
207
207
// for (int ii = 0; ii < 8; ++ii) {
208
208
// argzmm[ii] = argtype::loadu(arg + 8*ii);
209
- // arrzmm[ii] = vtype::template i64gather<sizeof(type_t)> (argzmm[ii], arr);
209
+ // arrzmm[ii] = vtype::i64gather(argzmm[ii], arr);
210
210
// arrzmm[ii] = sort_zmm_64bit<vtype, argtype>(arrzmm[ii], argzmm[ii]);
211
211
// }
212
212
//
@@ -257,17 +257,14 @@ type_t get_pivot_64bit(type_t *arr,
257
257
// median of 8
258
258
int64_t size = (right - left) / 8 ;
259
259
using reg_t = typename vtype::reg_t ;
260
- // TODO: Use gather here too:
261
- __m512i rand_index = _mm512_set_epi64 (arg[left + size],
262
- arg[left + 2 * size],
263
- arg[left + 3 * size],
264
- arg[left + 4 * size],
265
- arg[left + 5 * size],
266
- arg[left + 6 * size],
267
- arg[left + 7 * size],
268
- arg[left + 8 * size]);
269
- reg_t rand_vec
270
- = vtype::template i64gather<sizeof (type_t )>(rand_index, arr);
260
+ reg_t rand_vec = vtype::set (arr[arg[left + size]],
261
+ arr[arg[left + 2 * size]],
262
+ arr[arg[left + 3 * size]],
263
+ arr[arg[left + 4 * size]],
264
+ arr[arg[left + 5 * size]],
265
+ arr[arg[left + 6 * size]],
266
+ arr[arg[left + 7 * size]],
267
+ arr[arg[left + 8 * size]]);
271
268
// pivot will never be a nan, since there are no nan's!
272
269
reg_t sort = sort_zmm_64bit<vtype>(rand_vec);
273
270
return ((type_t *)&sort)[4 ];
0 commit comments