20 changes: 18 additions & 2 deletions examples/embedding/embedding.cpp
@@ -4,6 +4,7 @@
 #include "llama.h"
 
 #include <ctime>
+#include <cstring>
 #include <algorithm>
 
 #if defined(_MSC_VER)
@@ -236,9 +237,24 @@ int main(int argc, char ** argv) {
                 LOG("\n");
             }
         } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            const uint32_t n_cls_out = llama_model_n_cls_out(model);
+            std::vector<std::string> cls_out_labels;
+
+            for (uint32_t i = 0; i < n_cls_out; i++) {
+                const char * label = llama_model_get_classifier_label_by_index(model, i);
+                const std::string label_i = label == nullptr || strlen(label) == 0 ? std::to_string(i) : label;
+                cls_out_labels.emplace_back(label_i);
+            }
+
             for (int j = 0; j < n_embd_count; j++) {
-                // NOTE: if you change this log - update the tests in ci/run.sh
-                LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                for (uint32_t i = 0; i < n_cls_out; i++) {
+                    // NOTE: if you change this log - update the tests in ci/run.sh
+                    if (n_cls_out == 1) {
+                        LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+                    } else {
+                        LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
+                    }
+                }
             }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
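With a single classifier output the log line keeps its previous shape, so the existing checks in ci/run.sh continue to match; with several outputs, each score is printed once per output and tagged with its label. Illustrative output for a hypothetical two-label classifier (scores and label names made up here):

    rerank score 0:    1.029 [NEGATIVE]
    rerank score 0:   -2.646 [POSITIVE]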
8 changes: 7 additions & 1 deletion include/llama.h
@@ -506,6 +506,12 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
+    // Returns the number of classifier outputs (only valid for classifier models)
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns the label of the classifier output at index i (i < n_cls_out), or nullptr if no label was provided
+    LLAMA_API const char * llama_model_get_classifier_label_by_index(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
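A note on ownership, grounded in the implementation further down: the pointer returned by llama_model_get_classifier_label_by_index aliases the label string stored in the model's metadata map, so it stays valid for the lifetime of the model and must not be freed by the caller.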
@@ -912,7 +918,7 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
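Together, the two new entry points let a caller size and label the per-sequence score vector without knowing the model architecture. A minimal caller-side sketch (assuming model and ctx are already set up with LLAMA_POOLING_TYPE_RANK and a batch for seq_id 0 has been processed; error handling elided):

    // read all classifier scores for sequence 0
    const uint32_t n_cls_out = llama_model_n_cls_out(model);
    const float *  scores    = llama_get_embeddings_seq(ctx, 0); // n_cls_out floats

    for (uint32_t i = 0; i < n_cls_out; i++) {
        const char * label = llama_model_get_classifier_label_by_index(model, i);
        printf("output %u (%s): %8.3f\n", i, label != nullptr ? label : "unnamed", scores[i]);
    }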
7 changes: 4 additions & 3 deletions src/llama-context.cpp
@@ -808,16 +808,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                     } break;
                 case LLAMA_POOLING_TYPE_RANK:
                     {
-                        // extract the rerank score - a single float per sequence
+                        // extract the rerank score - n_cls_out floats per sequence
                         auto & embd_seq_out = embd_seq;
+                        const uint32_t n_cls_out = hparams.n_cls_out;
 
                         for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                             const llama_seq_id seq_id = ubatch.seq_id[s][0];
                             if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                                 continue;
                             }
-                            embd_seq_out[seq_id].resize(1);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                            embd_seq_out[seq_id].resize(n_cls_out);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                         }
                     } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
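The revised copy assumes t_embd packs the pooled scores back-to-back, n_cls_out floats per sequence, so the block for a given seq_id starts at float offset n_cls_out * seq_id. A worked example of the layout this indexing implies (n_cls_out and sequence count chosen for illustration):

    // t_embd with n_cls_out == 2 and three sequences:
    //   [ s0_out0, s0_out1, s1_out0, s1_out1, s2_out0, s2_out1 ]
    // the copy for seq_id == 1 therefore reads n_cls_out (= 2) floats
    // starting at byte offset (2 * 1) * sizeof(float)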
46 changes: 42 additions & 4 deletions src/llama-model.cpp
@@ -417,18 +417,41 @@ void llama_model::load_arch(llama_model_loader & ml) {
     }
 }
 
+struct LLM_KV_MATCH_WITHOUT_ARCH {
+    const LLM_KV kv_arch = LLM_KV(LLM_ARCH_UNKNOWN);
+    const std::string kv_arch_prefix = llm_arch_name(LLM_ARCH_UNKNOWN);
+
+    bool operator()(const llm_kv & kv, const std::string & kv_name) const
+    {
+        std::string kv_match = kv_arch(kv);
+        auto kv_arch_pos = kv_match.find(kv_arch_prefix);
+
+        return kv_name.find(kv_match.substr(kv_arch_pos == std::string::npos ? 0 : kv_arch_pos + kv_arch_prefix.size())) != std::string::npos;
+    }
+};
+
 void llama_model::load_hparams(llama_model_loader & ml) {
     const gguf_context * ctx = ml.meta.get();
 
     // get metadata as string
     for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
+        const char * name = gguf_get_key(ctx, i);
         gguf_type type = gguf_get_kv_type(ctx, i);
+
         if (type == GGUF_TYPE_ARRAY) {
-            continue;
+            if (LLM_KV_MATCH_WITHOUT_ARCH()(LLM_KV_CLASSIFIER_OUTPUT_LABELS, name)) {
+                const size_t n_items = gguf_get_arr_n(ctx, i);
+
+                for (size_t j = 0; j < n_items; j++) {
+                    const std::string name_i = format("%s.%zu", name, j);
+                    const std::string value = gguf_get_arr_str(ctx, i, j);
+                    gguf_kv.emplace(name_i, value);
+                }
+            }
+        } else {
+            const std::string value = gguf_kv_to_str(ctx, i);
+            gguf_kv.emplace(name, value);
         }
-        const char * name = gguf_get_key(ctx, i);
-        const std::string value = gguf_kv_to_str(ctx, i);
-        gguf_kv.emplace(name, value);
     }
 
     // get general kv
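The array branch flattens each label into its own scalar entry in gguf_kv, keyed by the array name plus the element index, while LLM_KV_MATCH_WITHOUT_ARCH matches the key by its architecture-independent suffix so the check works under any arch prefix. For example, assuming the conventional %s.classifier.output_labels key template and an illustrative BERT-style model, a GGUF array

    bert.classifier.output_labels = ["NEGATIVE", "POSITIVE"]

would be stored as

    bert.classifier.output_labels.0 = "NEGATIVE"
    bert.classifier.output_labels.1 = "POSITIVE"

which is exactly the per-index key that llama_model_get_classifier_label_by_index reconstructs below with format("%s.%u", ...).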
@@ -13593,6 +13616,21 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_get_classifier_label_by_index(const struct llama_model * model, uint32_t i) {
+    const std::string key = format("%s.%u", LLM_KV(model->arch)(LLM_KV_CLASSIFIER_OUTPUT_LABELS).c_str(), i);
+    const auto & it = model->gguf_kv.find(key);
+
+    if (it != model->gguf_kv.end()) {
+        return it->second.c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
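Because the lookup simply misses for an out-of-range index or a model without label metadata, callers get nullptr rather than an error; this is the contract embedding.cpp relies on above when it falls back to the numeric index for missing or empty labels.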