@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
77947794// ggml_compute_forward_argsort
77957795
77967796template <enum ggml_sort_order order>
7797- struct argsort_cmp {
7797+ struct cmp_argsort {
77987798 const float * data;
77997799 bool operator ()(int32_t a, int32_t b) const {
78007800 if constexpr (order == GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(
78337833
78347834 switch (order) {
78357835 case GGML_SORT_ORDER_ASC:
7836- std::sort (dst_data, dst_data + ne0, argsort_cmp <GGML_SORT_ORDER_ASC>{src_data});
7836+ std::sort (dst_data, dst_data + ne0, cmp_argsort <GGML_SORT_ORDER_ASC>{src_data});
78377837 break ;
78387838
78397839 case GGML_SORT_ORDER_DESC:
7840- std::sort (dst_data, dst_data + ne0, argsort_cmp <GGML_SORT_ORDER_DESC>{src_data});
7840+ std::sort (dst_data, dst_data + ne0, cmp_argsort <GGML_SORT_ORDER_DESC>{src_data});
78417841 break ;
78427842
78437843 default :
@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
78647864 }
78657865}
78667866
7867+ // ggml_compute_forward_top_k
7868+
7869+ struct cmp_top_k {
7870+ const float * data;
7871+ bool operator ()(int32_t a, int32_t b) const {
7872+ return data[a] > data[b];
7873+ }
7874+ };
7875+
7876+ static void ggml_compute_forward_top_k_f32 (
7877+ const ggml_compute_params * params,
7878+ ggml_tensor * dst) {
7879+
7880+ const ggml_tensor * src0 = dst->src [0 ];
7881+
7882+ GGML_TENSOR_UNARY_OP_LOCALS
7883+
7884+ GGML_ASSERT (nb0 == sizeof (float ));
7885+
7886+ const int ith = params->ith ;
7887+ const int nth = params->nth ;
7888+
7889+ const int64_t nr = ggml_nrows (src0);
7890+
7891+ const int top_k = ne0;
7892+
7893+ int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
7894+
7895+ for (int64_t i = ith; i < nr; i += nth) {
7896+ const float * src_data = (float *)((char *) src0->data + i*nb01);
7897+
7898+ for (int64_t j = 0 ; j < ne00; j++) {
7899+ tmp[j] = j;
7900+ }
7901+
7902+ std::partial_sort (tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
7903+
7904+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
7905+
7906+ std::copy (tmp, tmp + top_k, dst_data);
7907+
7908+ // emphasize that the order is not important
7909+ if (top_k > 1 ) {
7910+ std::swap (dst_data[0 ], dst_data[1 ]);
7911+ }
7912+ }
7913+ }
7914+
7915+ void ggml_compute_forward_top_k (
7916+ const ggml_compute_params * params,
7917+ ggml_tensor * dst) {
7918+
7919+ const ggml_tensor * src0 = dst->src [0 ];
7920+
7921+ switch (src0->type ) {
7922+ case GGML_TYPE_F32:
7923+ {
7924+ ggml_compute_forward_top_k_f32 (params, dst);
7925+ } break ;
7926+ default :
7927+ {
7928+ GGML_ABORT (" fatal error" );
7929+ }
7930+ }
7931+ }
7932+
78677933// ggml_compute_forward_flash_attn_ext
78687934
78697935static void ggml_compute_forward_flash_attn_ext_f16_one_chunk (
0 commit comments