Skip to content

Commit f4bca13

Browse files
committed
Correct language on pivot and partition functions
The comments and variable names are misleading: the function actually returns the position of the element immediately following the last element that is less than the pivot.
1 parent 9029f61 commit f4bca13

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

src/avx512-common-qsort.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ static inline zmm_t cmp_merge(zmm_t in1, zmm_t in2, opmask_t mask)
131131
return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max
132132
}
133133
/*
134-
* Parition one ZMM register based on the pivot and returns the index of the
135-
* last element that is less than equal to the pivot.
134+
* Partitions one ZMM register based on the pivot and returns the
135+
* number of elements that are greater than or equal to the pivot.
136136
*/
137137
template <typename vtype, typename type_t, typename zmm_t>
138138
static inline int32_t partition_vec(type_t *arr,
@@ -143,20 +143,20 @@ static inline int32_t partition_vec(type_t *arr,
143143
zmm_t *smallest_vec,
144144
zmm_t *biggest_vec)
145145
{
146-
/* which elements are larger than the pivot */
147-
typename vtype::opmask_t gt_mask = vtype::ge(curr_vec, pivot_vec);
148-
int32_t amount_gt_pivot = _mm_popcnt_u32((int32_t)gt_mask);
146+
/* which elements are larger than or equal to the pivot */
147+
typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec);
148+
int32_t amount_ge_pivot = _mm_popcnt_u32((int32_t)ge_mask);
149149
vtype::mask_compressstoreu(
150-
arr + left, vtype::knot_opmask(gt_mask), curr_vec);
150+
arr + left, vtype::knot_opmask(ge_mask), curr_vec);
151151
vtype::mask_compressstoreu(
152-
arr + right - amount_gt_pivot, gt_mask, curr_vec);
152+
arr + right - amount_ge_pivot, ge_mask, curr_vec);
153153
*smallest_vec = vtype::min(curr_vec, *smallest_vec);
154154
*biggest_vec = vtype::max(curr_vec, *biggest_vec);
155-
return amount_gt_pivot;
155+
return amount_ge_pivot;
156156
}
157157
/*
158158
* Partitions an array based on the pivot and returns the index of the
159-
* last element that is less than equal to the pivot.
159+
* first element that is greater than or equal to the pivot.
160160
*/
161161
template <typename vtype, typename type_t>
162162
static inline int64_t partition_avx512(type_t *arr,
@@ -188,7 +188,7 @@ static inline int64_t partition_avx512(type_t *arr,
188188

189189
if (right - left == vtype::numlanes) {
190190
zmm_t vec = vtype::loadu(arr + left);
191-
int32_t amount_gt_pivot = partition_vec<vtype>(arr,
191+
int32_t amount_ge_pivot = partition_vec<vtype>(arr,
192192
left,
193193
left + vtype::numlanes,
194194
vec,
@@ -197,7 +197,7 @@ static inline int64_t partition_avx512(type_t *arr,
197197
&max_vec);
198198
*smallest = vtype::reducemin(min_vec);
199199
*biggest = vtype::reducemax(max_vec);
200-
return left + (vtype::numlanes - amount_gt_pivot);
200+
return left + (vtype::numlanes - amount_ge_pivot);
201201
}
202202

203203
// first and last vtype::numlanes values are partitioned at the end
@@ -225,7 +225,7 @@ static inline int64_t partition_avx512(type_t *arr,
225225
left += vtype::numlanes;
226226
}
227227
// partition the current vector and save it on both sides of the array
228-
int32_t amount_gt_pivot
228+
int32_t amount_ge_pivot
229229
= partition_vec<vtype>(arr,
230230
l_store,
231231
r_store + vtype::numlanes,
@@ -234,27 +234,27 @@ static inline int64_t partition_avx512(type_t *arr,
234234
&min_vec,
235235
&max_vec);
236236
;
237-
r_store -= amount_gt_pivot;
238-
l_store += (vtype::numlanes - amount_gt_pivot);
237+
r_store -= amount_ge_pivot;
238+
l_store += (vtype::numlanes - amount_ge_pivot);
239239
}
240240

241241
/* partition and save vec_left and vec_right */
242-
int32_t amount_gt_pivot = partition_vec<vtype>(arr,
242+
int32_t amount_ge_pivot = partition_vec<vtype>(arr,
243243
l_store,
244244
r_store + vtype::numlanes,
245245
vec_left,
246246
pivot_vec,
247247
&min_vec,
248248
&max_vec);
249-
l_store += (vtype::numlanes - amount_gt_pivot);
250-
amount_gt_pivot = partition_vec<vtype>(arr,
249+
l_store += (vtype::numlanes - amount_ge_pivot);
250+
amount_ge_pivot = partition_vec<vtype>(arr,
251251
l_store,
252252
l_store + vtype::numlanes,
253253
vec_right,
254254
pivot_vec,
255255
&min_vec,
256256
&max_vec);
257-
l_store += (vtype::numlanes - amount_gt_pivot);
257+
l_store += (vtype::numlanes - amount_ge_pivot);
258258
*smallest = vtype::reducemin(min_vec);
259259
*biggest = vtype::reducemax(max_vec);
260260
return l_store;

0 commit comments

Comments
 (0)