Skip to content

Commit 2c13a0e

Browse files
committed
model : add harmony parser for gpt-oss
1 parent e54d41b commit 2c13a0e

File tree

7 files changed

+849
-9
lines changed

7 files changed

+849
-9
lines changed

common/chat-parser.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ class common_chat_msg_partial_exception : public std::runtime_error {
1515
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
1616
};
1717

18+
// Exception type for chat-message parse failures; distinct from
// common_chat_msg_partial_exception so callers can catch the two separately.
// In this commit it is thrown by the gpt-oss parser in common/chat.cpp when the
// expected token sequence is not found, and is caught and logged there.
class common_chat_msg_parse_exception : public std::runtime_error {
19+
public:
20+
// `message` is forwarded unchanged to std::runtime_error and is what what() reports.
common_chat_msg_parse_exception(const std::string & message) : std::runtime_error(message) {}
21+
};
22+
1823
class common_chat_msg_parser {
1924
std::string input_;
2025
bool is_partial_;

common/chat.cpp

Lines changed: 150 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
296296
}
297297
if (!msg.reasoning_content.empty()) {
298298
jmsg["reasoning_content"] = msg.reasoning_content;
299+
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
299300
}
300301
if (!msg.tool_name.empty()) {
301302
jmsg["name"] = msg.tool_name;
@@ -1314,17 +1315,160 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13141315
data.prompt = prompt;
13151316
data.format = COMMON_CHAT_FORMAT_GPT_OSS;
13161317

1317-
// TODO: support tool calls in GPT-OSS?
1318+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
1319+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1320+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1321+
std::vector<std::string> tool_rules;
1322+
foreach_function(inputs.tools, [&](const json & tool) {
1323+
const auto & function = tool.at("function");
1324+
std::string name = function.at("name");
1325+
auto parameters = function.at("parameters");
1326+
builder.resolve_refs(parameters);
1327+
1328+
tool_rules.push_back(builder.add_rule(name + "-call",
1329+
"\"" + name + "\"" + " space \"<|constrain|>\"? \"json\" space \"<|message|>\" " + builder.add_schema(name + "-args", parameters)
1330+
));
1331+
});
1332+
1333+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
1334+
builder.add_rule("root", "\"<|channel|>commentary\" space \"to=functions.\" " + tool_call);
1335+
1336+
data.grammar_triggers.push_back({
1337+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1338+
"<\\|channel\\|>commentary\\s+to=functions\\."
1339+
});
1340+
1341+
data.preserved_tokens = {
1342+
"<|channel|>",
1343+
"<|constrain|>",
1344+
"<|message|>",
1345+
"<|start|>",
1346+
"<|end|>",
1347+
};
1348+
});
1349+
}
13181350

13191351
return data;
13201352
}
13211353
// Parses gpt-oss output held by `builder`, one message at a time.
//
// Each message is expected to begin with <|channel|>(final|analysis|commentary); every
// message after the first is additionally introduced by <|start|>assistant.
// - common_chat_msg_parse_exception is caught here, logged with LOG_DBG, and parsing
//   resumes at the next <|start|>assistant (if any).
// - common_chat_msg_partial_exception is not caught here and propagates to the caller.
//
// NOTE(review): lines preceded by an "NNNN-" marker are the previous implementation that
// this commit removes; lines preceded by "NNNN+" are the new implementation.
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
1322-
// TODO @ngxson : this won't work with --special enabled, we should fix that
1323-
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
1324-
if (!builder.syntax().parse_tool_calls) {
1325-
builder.add_content(builder.consume_rest());
1326-
return;
1354+
// Token patterns used by the lambdas below; function-local statics so each regex is built once.
static const common_regex message_regex("<\\|message\\|>");
1355+
static const common_regex channel_regex("<\\|channel\\|>(final|analysis|commentary)");
1356+
static const common_regex start_regex("<\\|start\\|>assistant");
1357+
static const common_regex end_regex("<\\|end\\|>");
1358+
static const common_regex to_regex(" to=");
1359+
// Group 1 captures the function name after "functions."; group 2 (optional) captures the
// word that may follow an optional <|constrain|> (e.g. "json", as emitted by the grammar).
static const common_regex user_tool_call_regex(
1360+
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?:(?:<\\|constrain\\|>)?([a-zA-Z]+))?\\s*<\\|message\\|>"
1361+
);
1362+
// Matches a built-in tool name followed by anything up to and including <|message|>.
static const common_regex builtin_tool_call_regex("(?:browser|python)[\\s\\S]*<\\|message\\|>");
1363+
1364+
// Save the start of the message so we can roll back when we encounter a tool call and parse_tool_calls == false.
size_t message_start_pos = 0;
1365+
1366+
// Similarly, save the channel start so we can roll back to defer reasoning parsing to builder.
size_t channel_start_pos = 0;
1367+
1368+
// Returns the text between the current position and the next <|start|>assistant (searched
// from `from`), leaving the builder positioned at the start of that match. If there is no
// further match, returns everything that remains via consume_rest().
// The `false` argument is presumably "do not add the prelude to content" — confirm against
// common_chat_msg_parser::try_find_regex.
auto consume_until_next = [&](size_t from = std::string::npos) {
1369+
if (auto res = builder.try_find_regex(start_regex, from, false)) {
1370+
auto begin = res->groups[0].begin;
1371+
builder.move_to(begin);
1372+
return res->prelude;
1373+
}
1374+
return builder.consume_rest();
1375+
};
1376+
1377+
// Consumes "<|message|>" and the message body. Returns false (consuming nothing) when the
// input is not at "<|message|>".
// NOTE(review): when <|end|> IS found, nothing is added explicitly here; this relies on
// try_find_regex handling the text before the match (presumably adding it to content by
// default) — confirm. When <|end|> is missing, the remainder is added as content.
auto try_consume_message = [&]() {
1378+
if (builder.try_consume_regex(message_regex)) {
1379+
if (!builder.try_find_regex(end_regex)) {
1380+
builder.add_content(builder.consume_rest());
1381+
}
1382+
return true;
1383+
}
1384+
return false;
1385+
};
1386+
1387+
// Handles what follows " to=" in a channel header: a user function call, a built-in tool
// call, or (when tool-call parsing is disabled) raw content.
auto tool_call = [&]() {
1388+
if (!builder.syntax().parse_tool_calls) {
1389+
// Move back to the start and consume up to the next message
builder.move_to(message_start_pos);
1390+
// Search from message_start_pos + 1 so a <|start|>assistant located exactly at
// message_start_pos is not matched again.
builder.add_content(consume_until_next(message_start_pos + 1));
1391+
return;
1392+
}
1393+
1394+
if (auto res = builder.try_consume_regex(user_tool_call_regex)) {
1395+
auto name = builder.str(res->groups[1]);
1396+
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
1397+
// A rejected tool call or partially parsed arguments are both reported as an
// incomplete (partial) result rather than a parse error.
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
1398+
throw common_chat_msg_partial_exception("incomplete tool call");
1399+
}
1400+
}
1401+
} else if (builder.try_consume_regex(builtin_tool_call_regex)) {
1402+
// Built-in tools (browser/python) are recognised but not supported: the rest of the
// input is discarded and an error is logged.
builder.consume_rest();
1403+
LOG_ERR("builtin tool calls not implemented\n");
1404+
} else {
1405+
throw common_chat_msg_parse_exception("expected function call, got: " + consume_until_next());
1406+
}
1407+
};
1408+
1409+
// Commentary channel: either a tool call (" to=...") or a plain message.
auto commentary = [&]() {
1410+
if (builder.try_consume_regex(to_regex)) {
1411+
tool_call();
1412+
} else if (!try_consume_message()) {
1413+
throw common_chat_msg_parse_exception("expected: \" to=\" or <|message|>, got: " + consume_until_next());
1414+
}
1415+
};
1416+
1417+
// Analysis channel: either a tool call or reasoning text.
auto analysis = [&]() {
1418+
if (builder.try_consume_regex(to_regex)) {
1419+
tool_call(); // built-in tools can be called in the analysis channel
1420+
} else if (builder.try_consume_regex(message_regex)) {
1421+
// Defer reasoning parsing to builder
builder.move_to(channel_start_pos);
1422+
1423+
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE) {
1424+
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
1425+
} else {
1426+
// Reasoning extraction is disabled: after the rewind above, the text from the
// channel header up to the next message is added to content as-is.
builder.add_content(consume_until_next());
1427+
}
1428+
} else {
1429+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1430+
}
1431+
};
1432+
1433+
// Parses one message starting at "<|channel|>TYPE" and dispatches on TYPE.
auto channel = [&]() {
1434+
if (auto res = builder.try_consume_regex(channel_regex)) {
1435+
// Remember where the header began so analysis() can rewind to it.
channel_start_pos = res->groups[0].begin;
1436+
auto type = builder.str(res->groups[1]);
1437+
if (type == "analysis") {
1438+
analysis();
1439+
} else if (type == "commentary") {
1440+
commentary();
1441+
} else if (type == "final") {
1442+
if (!try_consume_message()) {
1443+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1444+
}
1445+
}
1446+
} else {
1447+
throw common_chat_msg_parse_exception("expected: <|channel|>, got: " + consume_until_next());
1448+
}
1449+
};
1450+
1451+
// The first message is parsed without requiring a leading <|start|>assistant
// (presumably because the prompt already ends with it — confirm against the template).
try {
1452+
channel();
1453+
} catch (const common_chat_msg_parse_exception & e) {
1454+
LOG_DBG("Parse error: %s\n", e.what());
1455+
}
1456+
1457+
// Read in complete messages until done or partial exception raised
while (auto res = builder.try_consume_regex(start_regex)) {
1458+
message_start_pos = res->groups[0].begin;
1459+
try {
1460+
channel();
1461+
} catch (const common_chat_msg_parse_exception & e) {
1462+
LOG_DBG("Parse error: %s\n", e.what());
1463+
}
13271469
}
1470+
1471+
// Discard whatever is left once no further <|start|>assistant can be consumed.
builder.consume_rest();
13281472
}
13291473

13301474
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {

0 commit comments

Comments
 (0)