File tree Expand file tree Collapse file tree 1 file changed +18
-2
lines changed Expand file tree Collapse file tree 1 file changed +18
-2
lines changed Original file line number Diff line number Diff line change 44#include " llama.h"
55
66#include < ctime>
7+ #include < cstring>
78#include < algorithm>
89
910#if defined(_MSC_VER)
@@ -236,9 +237,24 @@ int main(int argc, char ** argv) {
236237 LOG (" \n " );
237238 }
238239 } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
240+ const uint32_t n_cls_out = llama_model_n_cls_out (model);
241+ std::vector<std::string> cls_out_labels;
242+
243+ for (uint32_t i = 0 ; i < n_cls_out; i++) {
244+ const char * label = llama_model_get_classifier_label_by_index (model, i);
245+ const std::string label_i = label == nullptr || strlen (label) == 0 ? std::to_string (i) : label;
246+ cls_out_labels.emplace_back (label_i);
247+ }
248+
239249 for (int j = 0 ; j < n_embd_count; j++) {
240- // NOTE: if you change this log - update the tests in ci/run.sh
241- LOG (" rerank score %d: %8.3f\n " , j, emb[j * n_embd]);
250+ for (uint32_t i = 0 ; i < n_cls_out; i++) {
251+ // NOTE: if you change this log - update the tests in ci/run.sh
252+ if (n_cls_out == 1 ) {
253+ LOG (" rerank score %d: %8.3f\n " , j, emb[j * n_embd]);
254+ } else {
255+ LOG (" rerank score %d: %8.3f [%s]\n " , j, emb[j * n_embd + i], cls_out_labels[i].c_str ());
256+ }
257+ }
242258 }
243259 } else {
244260 // print the first part of the embeddings or for a single prompt, the full embedding
You can’t perform that action at this time.
0 commit comments