Skip to content

Commit 3a52f4c

Browse files
authored
update n_cls_out for any arch with labels
1 parent 76cf024 commit 3a52f4c

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/llama-model.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
552552
uint32_t n_vocab = 0;
553553
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
554554

555+
// for classifier models
556+
if (!classifier_labels.empty()) {
557+
hparams.n_cls_out = classifier_labels.size();
558+
}
559+
555560
// arch-specific KVs
556561
switch (arch) {
557562
case LLM_ARCH_LLAMA:
@@ -695,7 +700,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
695700
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
696701
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
697702
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
698-
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
699703

700704
switch (hparams.n_layer) {
701705
case 3:

0 commit comments

Comments
 (0)