
Commit 583cb83

ggml : add ggml_top_k (ggml-org#17365)
* ggml : add ggml_top_k
* cont : add ggml_argsort_top_k
* metal : add top_k support
* ggml : cleanup
* tests : add virtual err() function for test_case
* ggml : add comments
1 parent 05872ac commit 583cb83

15 files changed: +511, -80 lines changed

ggml/include/ggml.h

Lines changed: 12 additions & 4 deletions
@@ -530,6 +530,7 @@ extern "C" {
         GGML_OP_ARANGE,
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
+        GGML_OP_TOP_K,
         GGML_OP_LEAKY_RELU,
         GGML_OP_TRI,
         GGML_OP_FILL,
@@ -2258,18 +2259,25 @@ extern "C" {
             struct ggml_tensor  * a,
             enum ggml_sort_order  order);

-    GGML_API struct ggml_tensor * ggml_arange(
+    // similar to ggml_top_k but implemented as `argsort` + `view`
+    GGML_API struct ggml_tensor * ggml_argsort_top_k(
             struct ggml_context * ctx,
-            float                 start,
-            float                 stop,
-            float                 step);
+            struct ggml_tensor  * a,
+            int                   k);

     // top k elements per row
+    // note: the resulting top k indices are in no particular order
     GGML_API struct ggml_tensor * ggml_top_k(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   k);

+    GGML_API struct ggml_tensor * ggml_arange(
+            struct ggml_context * ctx,
+            float                 start,
+            float                 stop,
+            float                 step);
+
 #define GGML_KQ_MASK_PAD 64

 // q: [n_embd_k, n_batch, n_head, ne3 ]
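For orientation, here is a minimal host-side sketch of how the new call could be used. The shapes, the fill step, and the use of ggml_graph_compute_with_ctx() are illustrative assumptions, not part of this commit, and the include layout may differ between ggml revisions.

// sketch: select the top-3 indices per row with the new ggml_top_k()
// (assumed setup -- adjust sizes and thread count to your build of ggml)
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // 2 rows x 8 columns; top_k operates independently on each row
    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 2);
    struct ggml_tensor * idx = ggml_top_k(ctx, a, 3); // I32 indices, shape [3, 2], unordered

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, idx);

    // ... fill a->data with the scores to select from ...

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);

    ggml_free(ctx);
    return 0;
}

ggml_argsort_top_k() takes the same (ctx, a, k) arguments; per the new header comment it is built from argsort + view, whereas ggml_top_k() only guarantees which indices are returned, not their order.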

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 9 additions & 0 deletions
@@ -1927,6 +1927,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argsort(params, tensor);
             } break;
+        case GGML_OP_TOP_K:
+            {
+                ggml_compute_forward_top_k(params, tensor);
+            } break;
         case GGML_OP_LEAKY_RELU:
             {
                 ggml_compute_forward_leaky_relu(params, tensor);
@@ -2311,6 +2315,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
@@ -2834,6 +2839,10 @@ struct ggml_cplan ggml_graph_plan(
                     cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
                     cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
                 } break;
+            case GGML_OP_TOP_K:
+                {
+                    cur += sizeof(int32_t)*node->src[0]->ne[0]*n_tasks;
+                } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     const int64_t ne10 = node->src[1]->ne[0]; // DK
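The new GGML_OP_TOP_K case in ggml_graph_plan() reserves per-thread int32 scratch for the index buffer the CPU kernel sorts. A hedged sketch of driving this path with an explicit plan follows; the ggml_graph_plan()/ggml_graph_compute() signatures shown here follow recent ggml and may differ in other revisions.

// sketch: compute a graph with an explicit plan so the TOP_K scratch
// accounted for above ends up in plan.work_data (illustrative only)
#include <cstdint>
#include <vector>
#include "ggml.h"
#include "ggml-cpu.h"

static void compute_with_plan(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads, /*threadpool=*/NULL);

    std::vector<uint8_t> work(plan.work_size); // includes the per-thread top_k index scratch
    plan.work_data = work.data();

    ggml_graph_compute(gf, &plan);
}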

ggml/src/ggml-cpu/ops.cpp

Lines changed: 69 additions & 3 deletions
@@ -7794,7 +7794,7 @@ void ggml_compute_forward_timestep_embedding(
 // ggml_compute_forward_argsort

 template<enum ggml_sort_order order>
-struct argsort_cmp {
+struct cmp_argsort {
     const float * data;
     bool operator()(int32_t a, int32_t b) const {
         if constexpr (order == GGML_SORT_ORDER_ASC) {
@@ -7833,11 +7833,11 @@ static void ggml_compute_forward_argsort_f32(

     switch (order) {
         case GGML_SORT_ORDER_ASC:
-            std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_ASC>{src_data});
+            std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
             break;

         case GGML_SORT_ORDER_DESC:
-            std::sort(dst_data, dst_data + ne0, argsort_cmp<GGML_SORT_ORDER_DESC>{src_data});
+            std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
             break;

         default:
@@ -7864,6 +7864,72 @@ void ggml_compute_forward_argsort(
     }
 }

+// ggml_compute_forward_top_k
+
+struct cmp_top_k {
+    const float * data;
+    bool operator()(int32_t a, int32_t b) const {
+        return data[a] > data[b];
+    }
+};
+
+static void ggml_compute_forward_top_k_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int top_k = ne0;
+
+    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne00; j++) {
+            tmp[j] = j;
+        }
+
+        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
+
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+
+        std::copy(tmp, tmp + top_k, dst_data);
+
+        // emphasize that the order is not important
+        if (top_k > 1) {
+            std::swap(dst_data[0], dst_data[1]);
+        }
+    }
+}
+
+void ggml_compute_forward_top_k(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_top_k_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_flash_attn_ext

 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
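The core of ggml_compute_forward_top_k_f32() above is a partial selection over an index array. A standalone illustration of the same technique, independent of ggml (the row contents are made up):

// illustration: std::partial_sort over indices, comparing by value,
// leaves the indices of the k largest elements in the first k slots
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const std::vector<float> row = { 0.1f, 2.5f, -1.0f, 7.0f, 3.3f, 0.0f };
    const int k = 3;

    std::vector<int32_t> idx(row.size());
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...

    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int32_t a, int32_t b) { return row[a] > row[b]; });

    for (int i = 0; i < k; ++i) {
        std::printf("%d ", idx[i]); // prints: 3 4 1
    }
    std::printf("\n");
    return 0;
}

Note the deliberate std::swap of the first two results in the kernel: it enforces the documented contract that the returned top-k indices are in no particular order, so callers and tests cannot rely on a sorted output.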

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_top_k(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 58 additions & 0 deletions
@@ -1009,6 +1009,64 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
     return res;
 }

+// note: reuse the argsort kernel for top_k
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_TOP_K);
+
+    char base[256];
+    char name[256];
+
+    // note: the top_k kernel is always descending order
+    ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_TOP_K);
+
+    char base[256];
+    char name[256];
+
+    ggml_sort_order order = GGML_SORT_ORDER_DESC;
+
+    const char * order_str = "undefined";
+    switch (order) {
+        case GGML_SORT_ORDER_ASC:  order_str = "asc";  break;
+        case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k         (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge   (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin           (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm    (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions
@@ -905,6 +905,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 14 additions & 4 deletions
@@ -832,14 +832,19 @@ typedef struct {
 } ggml_metal_kargs_leaky_relu;

 typedef struct {
-    int64_t  ne00;
-    int64_t  ne01;
-    int64_t  ne02;
-    int64_t  ne03;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
     uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    int32_t  top_k;
 } ggml_metal_kargs_argsort;

 typedef struct {
@@ -851,6 +856,11 @@ typedef struct {
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    int32_t  top_k;
     int32_t  len;
 } ggml_metal_kargs_argsort_merge;
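The ne0..ne3 and top_k fields added to the argsort kargs let the same Metal kernel serve both GGML_OP_ARGSORT and GGML_OP_TOP_K. The encoder that fills these structs is not part of this hunk, so the mapping below is an assumption based on ggml's usual ne/nb conventions, written as a sketch:

// assumed host-side population of the extended kargs for a TOP_K dispatch
// (illustrative; the real encoder lives in the Metal backend sources)
#include "ggml.h"
#include "ggml-metal-impl.h" // internal header defining ggml_metal_kargs_argsort

static ggml_metal_kargs_argsort make_top_k_args(const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];

    ggml_metal_kargs_argsort args = {};

    args.ne00 = (int32_t) src0->ne[0]; // length of each row being selected over
    args.ne01 = (int32_t) src0->ne[1];
    args.ne02 = (int32_t) src0->ne[2];
    args.ne03 = (int32_t) src0->ne[3];

    args.nb00 = src0->nb[0];
    args.nb01 = src0->nb[1];
    args.nb02 = src0->nb[2];
    args.nb03 = src0->nb[3];

    args.ne0 = (int32_t) op->ne[0]; // k for top_k, full row length for argsort
    args.ne1 = (int32_t) op->ne[1];
    args.ne2 = (int32_t) op->ne[2];
    args.ne3 = (int32_t) op->ne[3];

    args.top_k = (int32_t) op->ne[0]; // assumption: k is also carried explicitly

    return args;
}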
