Commit 152e903

llama : add classification head (wip) [no ci]

1 parent 00f40ae

2 files changed: 14 additions, 2 deletions

common/arg.cpp

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.verbose_prompt = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ));
     add_opt(llama_arg(
         {"--no-display-prompt"},
         format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),

src/llama.cpp

Lines changed: 13 additions & 1 deletion
@@ -11291,8 +11291,20 @@ struct llm_build_context {
             inpL = cur;
         }
 
-        // final output
         cur = inpL;
+
+        // classification head
+        // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+        // TODO: become pooling layer?
+        if (model.cls) {
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls, cur), model.cls_b);
+
+            cur = ggml_tanh(ctx0, cur);
+
+            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+            // TODO: cur is now a scalar - what to do?
+        }
+
         cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
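
For reference, the arithmetic the new ggml nodes express, written out on plain host buffers: a dense layer with tanh activation followed by a projection to a single score, mirroring the referenced RobertaClassificationHead. This is a self-contained sketch with made-up toy dimensions and weights; the names cls, cls_b, cls_out and cls_out_b only stand in for the tensors the graph reads from model.cls etc., and nothing here is part of the commit.

    // minimal_cls_head.cpp - hypothetical illustration, not part of the commit.
    // Computes: score = cls_out . tanh(cls * h + cls_b) + cls_out_b
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 4;                                     // toy embedding size
        std::vector<float> h       = {0.1f, -0.2f, 0.3f, 0.4f};   // pooled embedding (cur)
        std::vector<float> cls     (n_embd * n_embd, 0.05f);      // dense weight  (model.cls)
        std::vector<float> cls_b   (n_embd, 0.0f);                // dense bias    (model.cls_b)
        std::vector<float> cls_out (n_embd, 0.1f);                // output weight (model.cls_out)
        const float        cls_out_b = 0.0f;                      // output bias   (model.cls_out_b)

        // dense layer + tanh activation
        std::vector<float> t(n_embd, 0.0f);
        for (int i = 0; i < n_embd; ++i) {
            float acc = cls_b[i];
            for (int j = 0; j < n_embd; ++j) {
                acc += cls[i*n_embd + j] * h[j];
            }
            t[i] = std::tanh(acc);
        }

        // projection down to a single value per sequence - this is the
        // scalar that the trailing TODO in the hunk refers to
        float score = cls_out_b;
        for (int i = 0; i < n_embd; ++i) {
            score += cls_out[i] * t[i];
        }

        printf("classification score: %f\n", score);
        return 0;
    }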
