Skip to content

Commit 12512e1

Browse files
authored
rename to n_cls_out
1 parent 2160aef commit 12512e1

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/llama-hparams.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ struct llama_hparams {
132132
bool use_kq_norm = true;
133133

134134
// for Classifiers
135-
uint32_t n_cls_out_labels = 1;
135+
uint32_t n_cls_out = 1;
136136

137137
// llama4
138138
uint32_t n_moe_layer_step = 0;

src/llama-model.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -683,7 +683,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
683683
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
684684
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
685685
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
686-
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out_labels, false);
686+
ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
687687

688688
switch (hparams.n_layer) {
689689
case 3:
@@ -2122,8 +2122,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
21222122
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
21232123
cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
21242124

2125-
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out_labels}, TENSOR_NOT_REQUIRED);
2126-
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out_labels}, TENSOR_NOT_REQUIRED);
2125+
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
2126+
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
21272127
}
21282128

21292129
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);

0 commit comments

Comments (0)