Skip to content

Commit 374fe09

Browse files
authored
ggml : use std::sort in ggml_argsort CPU implementation (ggml-org#17211)
* ggml : use std::sort in ggml_argsort CPU implementation * cont : add missing header
1 parent 8e878f0 commit 374fe09

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
#include "unary-ops.h"
88
#include "vec.h"
99

10-
#include <float.h>
10+
#include <cfloat>
1111
#include <algorithm>
12+
#include <functional>
1213

1314
// ggml_compute_forward_dup
1415

@@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32(
76827683
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
76837684

76847685
for (int64_t i = ith; i < nr; i += nth) {
7685-
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
76867686
const float * src_data = (float *)((char *) src0->data + i*nb01);
76877687

7688+
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7689+
76887690
for (int64_t j = 0; j < ne0; j++) {
76897691
dst_data[j] = j;
76907692
}
76917693

7692-
// C doesn't have a functional sort, so we do a bubble sort instead
7693-
for (int64_t j = 0; j < ne0; j++) {
7694-
for (int64_t k = j + 1; k < ne0; k++) {
7695-
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
7696-
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
7697-
int32_t tmp = dst_data[j];
7698-
dst_data[j] = dst_data[k];
7699-
dst_data[k] = tmp;
7700-
}
7701-
}
7694+
std::function<bool(int32_t, int32_t)> cmp;
7695+
7696+
// note: this might be causing memory allocations? ideally should be avoided if it's the case
7697+
switch (order) {
7698+
case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
7699+
case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
7700+
default: GGML_ABORT("invalid sort order");
77027701
}
7702+
7703+
std::sort(dst_data, dst_data + ne0, cmp);
77037704
}
77047705
}
77057706

0 commit comments

Comments
 (0)