|
7 | 7 | #include "unary-ops.h" |
8 | 8 | #include "vec.h" |
9 | 9 |
|
10 | | -#include <float.h> |
| 10 | +#include <cfloat> |
11 | 11 | #include <algorithm> |
| 12 | +#include <functional> |
12 | 13 |
|
13 | 14 | // ggml_compute_forward_dup |
14 | 15 |
|
@@ -7682,24 +7683,24 @@ static void ggml_compute_forward_argsort_f32( |
7682 | 7683 | ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0); |
7683 | 7684 |
|
7684 | 7685 | for (int64_t i = ith; i < nr; i += nth) { |
7685 | | - int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); |
7686 | 7686 | const float * src_data = (float *)((char *) src0->data + i*nb01); |
7687 | 7687 |
|
| 7688 | + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); |
| 7689 | + |
7688 | 7690 | for (int64_t j = 0; j < ne0; j++) { |
7689 | 7691 | dst_data[j] = j; |
7690 | 7692 | } |
7691 | 7693 |
|
7692 | | - // C doesn't have a functional sort, so we do a bubble sort instead |
7693 | | - for (int64_t j = 0; j < ne0; j++) { |
7694 | | - for (int64_t k = j + 1; k < ne0; k++) { |
7695 | | - if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || |
7696 | | - (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { |
7697 | | - int32_t tmp = dst_data[j]; |
7698 | | - dst_data[j] = dst_data[k]; |
7699 | | - dst_data[k] = tmp; |
7700 | | - } |
7701 | | - } |
| 7694 | + std::function<bool(int32_t, int32_t)> cmp; |
| 7695 | + |
| 7696 | + // note: this might be causing memory allocations? ideally should be avoided if it's the case |
| 7697 | + switch (order) { |
| 7698 | + case GGML_SORT_ORDER_ASC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] < src_data[b]; }; break; |
| 7699 | + case GGML_SORT_ORDER_DESC: cmp = [src_data](int32_t a, int32_t b) { return src_data[a] > src_data[b]; }; break; |
| 7700 | + default: GGML_ABORT("invalid sort order"); |
7702 | 7701 | } |
| 7702 | + |
| 7703 | + std::sort(dst_data, dst_data + ne0, cmp); |
7703 | 7704 | } |
7704 | 7705 | } |
7705 | 7706 |
|
|
0 commit comments