Skip to content

Commit 38ece05

Browse files
authored
move labels to llama_model
1 parent fa61273 commit 38ece05

File tree

4 files changed

+10
-24
lines changed

4 files changed

+10
-24
lines changed

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ int main(int argc, char ** argv) {
241241
std::vector<std::string> cls_out_labels;
242242

243243
for (uint32_t i = 0; i < n_cls_out; i++) {
244-
const char * label = llama_model_get_classifier_label_by_index(model, i);
244+
const char * label = llama_model_cls_label(model, i);
245245
const std::string label_i = label == nullptr || strlen(label) == 0 ? std::to_string(i) : label;
246246
cls_out_labels.emplace_back(label_i);
247247
}

include/llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ extern "C" {
510510
LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
511511

512512
// Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
513-
LLAMA_API const char * llama_model_get_classifier_label_by_index(const struct llama_model * model, uint32_t i);
513+
LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
514514

515515
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
516516

src/llama-model.cpp

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -417,19 +417,6 @@ void llama_model::load_arch(llama_model_loader & ml) {
417417
}
418418
}
419419

420-
struct LLM_KV_MATCH_WITHOUT_ARCH {
421-
const LLM_KV kv_arch = LLM_KV(LLM_ARCH_UNKNOWN);
422-
const std::string kv_arch_prefix = llm_arch_name(LLM_ARCH_UNKNOWN);
423-
424-
bool operator()(const llm_kv & kv, const std::string & kv_name) const
425-
{
426-
std::string kv_match = kv_arch(kv);
427-
auto kv_arch_pos = kv_match.find(kv_arch_prefix);
428-
429-
return kv_name.find(kv_match.substr(kv_arch_pos == std::string::npos ? 0 : kv_arch_pos + kv_arch_prefix.size())) != std::string::npos;
430-
}
431-
};
432-
433420
void llama_model::load_hparams(llama_model_loader & ml) {
434421
const gguf_context * ctx = ml.meta.get();
435422

@@ -439,13 +426,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
439426
gguf_type type = gguf_get_kv_type(ctx, i);
440427

441428
if (type == GGUF_TYPE_ARRAY) {
442-
if (LLM_KV_MATCH_WITHOUT_ARCH()(LLM_KV_CLASSIFIER_OUTPUT_LABELS, name)) {
429+
if (LLM_KV(arch)(LLM_KV_CLASSIFIER_OUTPUT_LABELS) == name) {
443430
const size_t n_items = gguf_get_arr_n(ctx, i);
444431

445432
for (size_t j = 0; j < n_items; j++) {
446-
const std::string name_i = format("%s.%zu", name, j);
447433
const std::string value = gguf_get_arr_str(ctx, i, j);
448-
gguf_kv.emplace(name_i, value);
434+
classifier_labels.emplace_back(value);
449435
}
450436
}
451437
} else {
@@ -13620,12 +13606,9 @@ uint32_t llama_model_n_cls_out(const struct llama_model * model) {
1362013606
return model->hparams.n_cls_out;
1362113607
}
1362213608

13623-
const char * llama_model_get_classifier_label_by_index(const struct llama_model * model, uint32_t i) {
13624-
const std::string key = format("%s.%u", LLM_KV(model->arch)(LLM_KV_CLASSIFIER_OUTPUT_LABELS).c_str(), i);
13625-
const auto & it = model->gguf_kv.find(key);
13626-
13627-
if (it != model->gguf_kv.end()) {
13628-
return it->second.c_str();
13609+
const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
13610+
if (i < model->classifier_labels.size()) {
13611+
return model->classifier_labels[i].c_str();
1362913612
}
1363013613

1363113614
return nullptr;

src/llama-model.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,9 @@ struct llama_model {
363363
// for quantize-stats only
364364
std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
365365

366+
// for classifier models
367+
std::vector<std::string> classifier_labels;
368+
366369
int64_t t_load_us = 0;
367370
int64_t t_start_us = 0;
368371

0 commit comments

Comments
 (0)