Skip to content

Commit a1b1018

Browse files
authored
Add support for multiple classifier outputs and labels
1 parent eb39499 commit a1b1018

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

src/llama-context.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -808,16 +808,17 @@ int llama_context::encode(llama_batch & inp_batch) {
808808
} break;
809809
case LLAMA_POOLING_TYPE_RANK:
810810
{
811-
// extract the rerank score - a single float per sequence
811+
// extract the rerank score - n_cls_out floats per sequence
812812
auto & embd_seq_out = embd_seq;
813+
const uint32_t n_cls_out = hparams.n_cls_out;
813814

814815
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
815816
const llama_seq_id seq_id = ubatch.seq_id[s][0];
816817
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
817818
continue;
818819
}
819-
embd_seq_out[seq_id].resize(1);
820-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
820+
embd_seq_out[seq_id].resize(n_cls_out);
821+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
821822
}
822823
} break;
823824
case LLAMA_POOLING_TYPE_UNSPECIFIED:

src/llama-model.cpp

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -417,18 +417,41 @@ void llama_model::load_arch(llama_model_loader & ml) {
417417
}
418418
}
419419

420+
struct LLM_KV_MATCH_WITHOUT_ARCH {
421+
const LLM_KV kv_arch = LLM_KV(LLM_ARCH_UNKNOWN);
422+
const std::string kv_arch_prefix = llm_arch_name(LLM_ARCH_UNKNOWN);
423+
424+
bool operator()(const llm_kv & kv, const std::string & kv_name) const
425+
{
426+
std::string kv_match = kv_arch(kv);
427+
auto kv_arch_pos = kv_match.find(kv_arch_prefix);
428+
429+
return kv_name.find(kv_match.substr(kv_arch_pos == std::string::npos ? 0 : kv_arch_pos + kv_arch_prefix.size())) != std::string::npos;
430+
}
431+
};
432+
420433
void llama_model::load_hparams(llama_model_loader & ml) {
421434
const gguf_context * ctx = ml.meta.get();
422435

423436
// get metadata as string
424437
for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
438+
const char * name = gguf_get_key(ctx, i);
425439
gguf_type type = gguf_get_kv_type(ctx, i);
440+
426441
if (type == GGUF_TYPE_ARRAY) {
427-
continue;
442+
if (LLM_KV_MATCH_WITHOUT_ARCH()(LLM_KV_CLASSIFIER_OUTPUT_LABELS, name)) {
443+
const size_t n_items = gguf_get_arr_n(ctx, i);
444+
445+
for (size_t j = 0; j < n_items; j++) {
446+
const std::string name_i = format("%s.%zu", name, j);
447+
const std::string value = gguf_get_arr_str(ctx, i, j);
448+
gguf_kv.emplace(name_i, value);
449+
}
450+
}
451+
} else {
452+
const std::string value = gguf_kv_to_str(ctx, i);
453+
gguf_kv.emplace(name, value);
428454
}
429-
const char * name = gguf_get_key(ctx, i);
430-
const std::string value = gguf_kv_to_str(ctx, i);
431-
gguf_kv.emplace(name, value);
432455
}
433456

434457
// get general kv
@@ -13593,6 +13616,21 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
1359313616
return model->hparams.n_head_kv();
1359413617
}
1359513618

13619+
uint32_t llama_model_n_cls_out(const struct llama_model * model) {
13620+
return model->hparams.n_cls_out;
13621+
}
13622+
13623+
const char * llama_model_get_classifier_label_by_index(const struct llama_model * model, uint32_t i) {
13624+
const std::string key = format("%s.%u", LLM_KV(model->arch)(LLM_KV_CLASSIFIER_OUTPUT_LABELS).c_str(), i);
13625+
const auto & it = model->gguf_kv.find(key);
13626+
13627+
if (it != model->gguf_kv.end()) {
13628+
return it->second.c_str();
13629+
}
13630+
13631+
return nullptr;
13632+
}
13633+
1359613634
// deprecated
1359713635
int32_t llama_n_ctx_train(const llama_model * model) {
1359813636
return llama_model_n_ctx_train(model);

0 commit comments

Comments (0)