Skip to content

Commit 879dec3

Browse files
authored
ggml-cpu : use template for argsort (#17222)
1 parent 97d5117 commit 879dec3

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
76657665

76667666
// ggml_compute_forward_argsort
76677667

7668+
template<enum ggml_sort_order order>
7669+
struct argsort_cmp {
7670+
const float * data;
7671+
bool operator()(int32_t a, int32_t b) const {
7672+
if constexpr (order == GGML_SORT_ORDER_ASC) {
7673+
return data[a] < data[b];
7674+
} else {
7675+
return data[a] > data[b];
7676+
}
7677+
}
7678+
};
7679+
76687680
static void ggml_compute_forward_argsort_f32(
76697681
const ggml_compute_params * params,
76707682
ggml_tensor * dst) {
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
76917703
dst_data[j] = j;
76927704
}
76937705

7694-
std::function<bool(int32_t, int32_t)> cmp;
7695-
7696-
// note: this might be causing memory allocations? ideally should be avoided if it's the case
76977706
switch (order) {
7698-
case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break;
7699-
case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break;
7700-
default: GGML_ABORT("invalid sort order");
7701-
}
7707+
case GGML_SORT_ORDER_ASC:
7708+
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
7709+
break;
77027710

7703-
std::sort(dst_data, dst_data + ne0, cmp);
7711+
case GGML_SORT_ORDER_DESC:
7712+
std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
7713+
break;
7714+
7715+
default:
7716+
GGML_ABORT("invalid sort order");
7717+
}
77047718
}
77057719
}
77067720

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7631,6 +7631,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
76317631
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
76327632
}
76337633

7634+
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
7635+
76347636
return test_cases;
76357637
}
76367638

0 commit comments

Comments
 (0)