@@ -7665,6 +7665,18 @@ void ggml_compute_forward_timestep_embedding(
76657665
76667666// ggml_compute_forward_argsort
76677667
7668+ template <enum ggml_sort_order order>
7669+ struct argsort_cmp {
7670+ const float * data;
7671+ bool operator ()(int32_t a, int32_t b) const {
7672+ if constexpr (order == GGML_SORT_ORDER_ASC) {
7673+ return data[a] < data[b];
7674+ } else {
7675+ return data[a] > data[b];
7676+ }
7677+ }
7678+ };
7679+
76687680static void ggml_compute_forward_argsort_f32 (
76697681 const ggml_compute_params * params,
76707682 ggml_tensor * dst) {
@@ -7691,16 +7703,18 @@ static void ggml_compute_forward_argsort_f32(
76917703 dst_data[j] = j;
76927704 }
76937705
7694- std::function<bool (int32_t , int32_t )> cmp;
7695-
7696- // note: this might be causing memory allocations? ideally should be avoided if it's the case
76977706 switch (order) {
7698- case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break ;
7699- case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break ;
7700- default : GGML_ABORT (" invalid sort order" );
7701- }
7707+ case GGML_SORT_ORDER_ASC:
7708+ std::sort (dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
7709+ break ;
77027710
7703- std::sort (dst_data, dst_data + ne0, cmp);
7711+ case GGML_SORT_ORDER_DESC:
7712+ std::sort (dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
7713+ break ;
7714+
7715+ default :
7716+ GGML_ABORT (" invalid sort order" );
7717+ }
77047718 }
77057719}
77067720
0 commit comments