@@ -417,18 +417,41 @@ void llama_model::load_arch(llama_model_loader & ml) {
417417 }
418418}
419419
420+ struct LLM_KV_MATCH_WITHOUT_ARCH {
421+ const LLM_KV kv_arch = LLM_KV(LLM_ARCH_UNKNOWN);
422+ const std::string kv_arch_prefix = llm_arch_name(LLM_ARCH_UNKNOWN);
423+
424+ bool operator()(const llm_kv & kv, const std::string & kv_name) const
425+ {
426+ std::string kv_match = kv_arch(kv);
427+ auto kv_arch_pos = kv_match.find(kv_arch_prefix);
428+
429+ return kv_name.find(kv_match.substr(kv_arch_pos == std::string::npos ? 0 : kv_arch_pos + kv_arch_prefix.size())) != std::string::npos;
430+ }
431+ };
432+
420433void llama_model::load_hparams(llama_model_loader & ml) {
421434 const gguf_context * ctx = ml.meta.get();
422435
423436 // get metadata as string
424437 for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
438+ const char * name = gguf_get_key(ctx, i);
425439 gguf_type type = gguf_get_kv_type(ctx, i);
440+
426441 if (type == GGUF_TYPE_ARRAY) {
427- continue;
442+ if (LLM_KV_MATCH_WITHOUT_ARCH()(LLM_KV_CLASSIFIER_OUTPUT_LABELS, name)) {
443+ const size_t n_items = gguf_get_arr_n(ctx, i);
444+
445+ for (size_t j = 0; j < n_items; j++) {
446+ const std::string name_i = format("%s.%zu", name, j);
447+ const std::string value = gguf_get_arr_str(ctx, i, j);
448+ gguf_kv.emplace(name_i, value);
449+ }
450+ }
451+ } else {
452+ const std::string value = gguf_kv_to_str(ctx, i);
453+ gguf_kv.emplace(name, value);
428454 }
429- const char * name = gguf_get_key(ctx, i);
430- const std::string value = gguf_kv_to_str(ctx, i);
431- gguf_kv.emplace(name, value);
432455 }
433456
434457 // get general kv
@@ -13593,6 +13616,21 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
1359313616 return model->hparams.n_head_kv();
1359413617}
1359513618
13619+ uint32_t llama_model_n_cls_out(const struct llama_model * model) {
13620+ return model->hparams.n_cls_out;
13621+ }
13622+
13623+ const char * llama_model_get_classifier_label_by_index(const struct llama_model * model, uint32_t i) {
13624+ const std::string key = format("%s.%u", LLM_KV(model->arch)(LLM_KV_CLASSIFIER_OUTPUT_LABELS).c_str(), i);
13625+ const auto & it = model->gguf_kv.find(key);
13626+
13627+ if (it != model->gguf_kv.end()) {
13628+ return it->second.c_str();
13629+ }
13630+
13631+ return nullptr;
13632+ }
13633+
1359613634// deprecated
1359713635int32_t llama_n_ctx_train(const llama_model * model) {
1359813636 return llama_model_n_ctx_train(model);
0 commit comments