Skip to content

Commit fa61273

Browse files
authored
show multiple rankings and associated labels
ggml-ci
1 parent 6ef43ba commit fa61273

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

examples/embedding/embedding.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "llama.h"
55

66
#include <ctime>
7+
#include <cstring>
78
#include <algorithm>
89

910
#if defined(_MSC_VER)
@@ -236,9 +237,24 @@ int main(int argc, char ** argv) {
236237
LOG("\n");
237238
}
238239
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
240+
const uint32_t n_cls_out = llama_model_n_cls_out(model);
241+
std::vector<std::string> cls_out_labels;
242+
243+
for (uint32_t i = 0; i < n_cls_out; i++) {
244+
const char * label = llama_model_get_classifier_label_by_index(model, i);
245+
const std::string label_i = label == nullptr || strlen(label) == 0 ? std::to_string(i) : label;
246+
cls_out_labels.emplace_back(label_i);
247+
}
248+
239249
for (int j = 0; j < n_embd_count; j++) {
240-
// NOTE: if you change this log - update the tests in ci/run.sh
241-
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
250+
for (uint32_t i = 0; i < n_cls_out; i++) {
251+
// NOTE: if you change this log - update the tests in ci/run.sh
252+
if (n_cls_out == 1) {
253+
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
254+
} else {
255+
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
256+
}
257+
}
242258
}
243259
} else {
244260
// print the first part of the embeddings or for a single prompt, the full embedding

0 commit comments

Comments
 (0)