Skip to content

Commit c3d646a

Browse files
committed
Add num_bos num_eos to GenerationConfig
1 parent d02a9b3 commit c3d646a

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

examples/models/llama/main.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ DEFINE_int32(
4242
-1,
4343
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
4444

45+
// Command-line flag --num_bos: how many BOS tokens to prepend to the prompt.
// Defaults to 0 (no BOS tokens added). The terminating semicolon is required,
// matching the other flag definitions in this file.
DEFINE_int32(
    num_bos,
    0,
    "Number of BOS tokens to prepend to the prompt. Defaults to 0. If > 0, the prompt will be prepended with BOS tokens. This is useful for models that expect one or more BOS token at the start.");
50+
51+
// Command-line flag --num_eos: how many EOS tokens to append to the prompt.
// Defaults to 0 (no EOS tokens added). The terminating semicolon is required,
// matching the other flag definitions in this file.
DEFINE_int32(
    num_eos,
    0,
    "Number of EOS tokens to append to the prompt. Defaults to 0. If > 0, the prompt will be appended with EOS tokens. This is useful for models that expect one or more EOS token at the end.");
56+
4557
// Command-line flag --warmup: opt-in (default false) request for a warmup run.
DEFINE_bool(warmup, false, "Whether to run a warmup run.");
4658

4759
int32_t main(int32_t argc, char** argv) {

extension/llm/runner/irunner.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ struct GenerationConfig {
4949
// Temperature for sampling (higher = more random)
5050
float temperature = 0.8f;
5151

52+
// Number of BOS tokens to prepend and EOS tokens to append to the prompt
53+
int32_t num_bos = 0;
54+
int32_t num_eos = 0;
55+
5256
/**
5357
* Resolve the maximum number of new tokens to generate based on constraints.
5458
*

extension/llm/runner/text_llm_runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ Error TextLLMRunner::generate_from_pos(
117117

118118
::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
119119
prompt,
120-
/* bos */ 0,
121-
/* eos */ 0);
120+
/*bos=*/config.num_bos,
121+
/*eos=*/config.num_eos);
122122

123123
ET_CHECK_TK_OK_OR_RETURN_ERROR(
124124
encode_res.error(), "Failed to encode prompt %s", prompt.c_str());

0 commit comments

Comments
 (0)