diff --git a/common/chat.cpp b/common/chat.cpp index 316bd24170c9e..2e17f3e625760 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1319,9 +1319,55 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - // TODO @ngxson : this won't work with --special enabled, we should fix that - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>"); + static const common_regex next_block( + "<\\|start\\|>assistant<\\|channel\\|>(?:commentary\\s+to=functions\\.([^\\s<\\|]+)(?:\\s+json)?|final)<\\|message\\|>"); + if (!builder.syntax().parse_tool_calls) { + if (auto res = builder.try_find_regex(next_block)) { + while (true) { + std::string fname = builder.str(res->groups[1]); + const std::string header = builder.str(res->groups[0]); + if (fname.size() > 4 && header.find(" json<|message|>") == std::string::npos) { + if (fname.rfind("json") == fname.size() - 4) { + fname.resize(fname.size() - 4); + } + } + if (!fname.empty()) { + if (!builder.try_consume_json_with_dumped_args({{}})) { + break; + } + res = builder.try_find_regex(next_block); + if (!res) break; + continue; + } + builder.add_content(builder.consume_rest()); + return; + } + } + builder.add_content(builder.consume_rest()); + return; + } + + while (true) { + auto res = builder.try_find_regex(next_block); + if (!res) { + builder.add_content(builder.consume_rest()); + return; + } + std::string fname = builder.str(res->groups[1]); + const std::string header = builder.str(res->groups[0]); + if (fname.size() > 4 && header.find(" json<|message|>") == std::string::npos) { + if (fname.rfind("json") == fname.size() - 4) { + fname.resize(fname.size() - 4); + } + } + if (!fname.empty()) { + auto arguments = builder.consume_json_with_dumped_args({{}}); + if (!builder.add_tool_call(fname, "", arguments.value) || arguments.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + continue; + } builder.add_content(builder.consume_rest()); return; } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 99b4b4d5bac7b..f18cab744385b 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1441,6 +1441,64 @@ static void test_template_output_parsers() { } } +static void test_gpt_oss_parser() { + printf("[%s]\n", __func__); + + // Standard commentary header with explicit ' json' + { + const std::string input = + "<|start|>assistant<|channel|>commentary to=functions.shell json<|message|>" \ + "{\\\"command\\\":[\\\"bash\\\",\\\"-lc\\\",\\\"echo hi\\\"]}" \ + "<|start|>assistant<|channel|>final<|message|>Done"; + common_chat_msg expected = simple_assist_msg("Done", "", "shell", + "{\"command\":[\"bash\",\"-lc\",\"echo hi\"]}"); + assert_msg_equals( + expected, + common_chat_parse( + input, + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GPT_OSS})); + } + + // Commentary header where 'json' is glued to function name (e.g., 'shelljson') + { + const std::string input = + "<|start|>assistant<|channel|>commentary to=functions.shelljson<|message|>" \ + "{\\\"command\\\":[\\\"bash\\\",\\\"-lc\\\",\\\"echo hi\\\"]}" \ + "<|start|>assistant<|channel|>final<|message|>Done"; + common_chat_msg expected = simple_assist_msg("Done", "", "shell", + "{\"command\":[\"bash\",\"-lc\",\"echo hi\"]}"); + assert_msg_equals( + expected, + common_chat_parse( + input, + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GPT_OSS})); + } + + // Multiple commentary tool calls then final + { + const std::string input = + "<|start|>assistant<|channel|>commentary to=functions.shell json<|message|>" \ + "{\\\"command\\\":[\\\"bash\\\",\\\"-lc\\\",\\\"echo hi\\\"]}" \ + "<|start|>assistant<|channel|>commentary to=functions.update_plan json<|message|>" \ + "{\\\"plan\\\":[{\\\"step\\\":\\\"x\\\",\\\"status\\\":\\\"in_progress\\\"}]}" \ + "<|start|>assistant<|channel|>final<|message|>ok"; + common_chat_msg expected; + expected.role = "assistant"; + expected.content = "ok"; + expected.tool_calls.push_back({"shell", "{\"command\":[\"bash\",\"-lc\",\"echo hi\"]}", ""}); + expected.tool_calls.push_back({"update_plan", "{\"plan\":[{\"step\":\"x\",\"status\":\"in_progress\"}]}", ""}); + + assert_msg_equals( + expected, + common_chat_parse( + input, + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GPT_OSS})); + } +} + static void test_msg_diffs_compute() { printf("[%s]\n", __func__); { @@ -1564,6 +1622,7 @@ int main(int argc, char ** argv) { test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); test_template_output_parsers(); + test_gpt_oss_parser(); std::cout << "\n[chat] All tests passed!" << '\n'; } return 0;