Commit 657b8a7
chat: handle gpt-oss return/end token inconsistency (ggml-org#15421)
This commit addresses an inconsistency during inference by adding a new member to the `templates_params` struct that indicates whether the chat is in inference mode. The gpt-oss specific function `common_chat_params_init_gpt_oss` checks this flag together with `add_generation_prompt` to determine whether it should replace the `<|return|>` token with the `<|end|>` token in the prompt.

The motivation for this change is to ensure that the formatted prompt of past messages in `common_chat_format_single` matches the corresponding prefix of the formatted conversation that includes the new message. The problem is that the gpt-oss template emits different end tags: `<|return|>` when `add_generation_prompt` is false, and `<|end|>` when `add_generation_prompt` is true. Since `<|return|>` is three characters longer than `<|end|>`, the substring offset computed from the length of the formatted past messages overshoots by three characters, so tokenization of the new message starts with 'tart|>' instead of '<|start|>'.

Resolves: ggml-org#15417
1 parent ec5ab1a commit 657b8a7
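
To make the failure mode concrete, here is a minimal, self-contained sketch, not the actual llama.cpp code: the toy `format` helper and the sample messages are invented for illustration. It mimics the relevant gpt-oss behaviour (every message ends with `<|end|>`, except that the final one ends with `<|return|>` when `add_generation_prompt` is false) and then diffs the two formatted strings by length, the way `common_chat_format_single` does:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for the gpt-oss template: each message is rendered as
// <|start|>{role}<|message|>{content} plus an end tag. The final message
// gets <|return|> when no generation prompt is requested, reproducing
// the inconsistency described above.
static std::string format(const std::vector<std::pair<std::string, std::string>> & msgs,
                          bool add_generation_prompt) {
    std::string out;
    for (size_t i = 0; i < msgs.size(); i++) {
        const bool last = i + 1 == msgs.size();
        out += "<|start|>" + msgs[i].first + "<|message|>" + msgs[i].second;
        out += (last && !add_generation_prompt) ? "<|return|>" : "<|end|>";
    }
    if (add_generation_prompt) {
        out += "<|start|>assistant";
    }
    return out;
}

int main() {
    std::vector<std::pair<std::string, std::string>> past = {{"user", "hi"}};
    auto with_new = past;
    with_new.emplace_back("assistant", "hello");

    // Format the past messages alone, then the conversation including the
    // new message, both without a generation prompt.
    const std::string fmt_past = format(past,     /*add_generation_prompt=*/false);
    const std::string fmt_full = format(with_new, /*add_generation_prompt=*/false);

    // fmt_past ends with <|return|> (10 chars) while the same turn inside
    // fmt_full ends with <|end|> (7 chars), so the offset overshoots by 3:
    // prints "tart|>assistant<|message|>hello<|return|>".
    std::cout << fmt_full.substr(fmt_past.size()) << "\n";
    return 0;
}
```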

File tree: 1 file changed

common/chat.cpp

Lines changed: 12 additions & 0 deletions
@@ -147,6 +147,7 @@ struct templates_params {
     json extra_context;
     bool add_bos;
     bool add_eos;
+    bool is_inference = true;
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -1336,6 +1337,17 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
     common_chat_params data;
     auto prompt = apply(tmpl, inputs);
 
+    // Check if we need to replace the return token with end token during
+    // inference and without generation prompt. For more details see:
+    // https://github.com/ggml-org/llama.cpp/issues/15417
+    if (inputs.is_inference && !inputs.add_generation_prompt) {
+        static constexpr std::string_view return_token = "<|return|>";
+        static constexpr std::string_view end_token = "<|end|>";
+        if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
+            prompt.replace(pos, return_token.length(), end_token);
+        }
+    }
+
     data.prompt = prompt;
     data.format = COMMON_CHAT_FORMAT_GPT_OSS;
 
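For reference, the added replacement logic in isolation; this is a standalone sketch mirroring the hunk above, where the `replace_return_with_end` wrapper and the sample prompt are illustrative, not part of the patch:

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Mirrors the replacement added in common_chat_params_init_gpt_oss:
// rewrite the final <|return|> into <|end|>, so that a conversation
// formatted without a generation prompt stays a byte-for-byte prefix
// of the same conversation formatted with further messages appended.
static void replace_return_with_end(std::string & prompt) {
    static constexpr std::string_view return_token = "<|return|>";
    static constexpr std::string_view end_token = "<|end|>";
    // rfind targets the last occurrence, i.e. the final end tag of the
    // formatted prompt; earlier text is left untouched.
    if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
        prompt.replace(pos, return_token.length(), end_token);
    }
}

int main() {
    std::string prompt = "<|start|>user<|message|>hi<|return|>";
    replace_return_with_end(prompt);
    assert(prompt == "<|start|>user<|message|>hi<|end|>");
    return 0;
}
```

Note that `is_inference` defaults to `true` in `templates_params`, so existing call sites pick up the replacement without changes; presumably only callers that explicitly set the flag to false keep the raw `<|return|>` tag.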