Skip to content

Commit 02a5e6c

Browse files
committed
model : add harmony parser for gpt-oss
1 parent 3ea913f commit 02a5e6c

File tree

7 files changed

+849
-9
lines changed

7 files changed

+849
-9
lines changed

common/chat-parser.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ class common_chat_msg_partial_exception : public std::runtime_error {
1515
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
1616
};
1717

18+
// Exception type for chat-message parse errors. It derives from
// std::runtime_error and adds no state of its own. It is a distinct type from
// common_chat_msg_partial_exception (declared just above), so the two can be
// caught separately.
class common_chat_msg_parse_exception : public std::runtime_error {
19+
public:
20+
// Forwards `message` unchanged to std::runtime_error; retrievable via what().
common_chat_msg_parse_exception(const std::string & message) : std::runtime_error(message) {}
21+
};
22+
1823
class common_chat_msg_parser {
1924
std::string input_;
2025
bool is_partial_;

common/chat.cpp

Lines changed: 150 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
296296
}
297297
if (!msg.reasoning_content.empty()) {
298298
jmsg["reasoning_content"] = msg.reasoning_content;
299+
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
299300
}
300301
if (!msg.tool_name.empty()) {
301302
jmsg["name"] = msg.tool_name;
@@ -1338,17 +1339,160 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13381339
data.prompt = prompt;
13391340
data.format = COMMON_CHAT_FORMAT_GPT_OSS;
13401341

1341-
// TODO: support tool calls in GPT-OSS?
1342+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
1343+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1344+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1345+
std::vector<std::string> tool_rules;
1346+
foreach_function(inputs.tools, [&](const json & tool) {
1347+
const auto & function = tool.at("function");
1348+
std::string name = function.at("name");
1349+
auto parameters = function.at("parameters");
1350+
builder.resolve_refs(parameters);
1351+
1352+
tool_rules.push_back(builder.add_rule(name + "-call",
1353+
"\"" + name + "\"" + " space \"<|constrain|>\"? \"json\" space \"<|message|>\" " + builder.add_schema(name + "-args", parameters)
1354+
));
1355+
});
1356+
1357+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
1358+
builder.add_rule("root", "\"<|channel|>commentary\" space \"to=functions.\" " + tool_call);
1359+
1360+
data.grammar_triggers.push_back({
1361+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1362+
"<\\|channel\\|>commentary\\s+to=functions\\."
1363+
});
1364+
1365+
data.preserved_tokens = {
1366+
"<|channel|>",
1367+
"<|constrain|>",
1368+
"<|message|>",
1369+
"<|start|>",
1370+
"<|end|>",
1371+
};
1372+
});
1373+
}
13421374

13431375
return data;
13441376
}
13451377
// Parses gpt-oss ("harmony") model output held by `builder` into content,
// reasoning and tool calls.
//
// The output is read as a sequence of messages. Each message is expected to
// look like `<|channel|>(final|analysis|commentary)` followed either by
// ` to=<recipient>` (a tool call) or by `<|message|>` and the message text.
// Every message after the first is introduced by `<|start|>assistant`.
//
// Error handling: a common_chat_msg_parse_exception raised while parsing one
// message is caught here and only logged at debug level. A
// common_chat_msg_partial_exception is NOT caught here and propagates to the
// caller.
//
// NOTE: in this diff view, lines preceded by an `NNNN-` marker belong to the
// removed implementation; lines preceded by an `NNNN+` marker are the new one.
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
1346-
// TODO @ngxson : this won't work with --special enabled, we should fix that
1347-
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
1348-
if (!builder.syntax().parse_tool_calls) {
1349-
builder.add_content(builder.consume_rest());
1350-
return;
1378+
// Marker that separates a message header from its body.
static const common_regex message_regex("<\\|message\\|>");
1379+
// Channel header; capture group 1 is the channel type.
static const common_regex channel_regex("<\\|channel\\|>(final|analysis|commentary)");
1380+
// Marker that introduces each assistant message after the first one.
static const common_regex start_regex("<\\|start\\|>assistant");
1381+
// Marker that terminates a message body.
static const common_regex end_regex("<\\|end\\|>");
1382+
// Recipient prefix; its presence after the channel header selects the tool-call path.
static const common_regex to_regex(" to=");
1383+
// User-defined function call header. Group 1 is the function name; group 2
// (optional) is an alphabetic word that may be preceded by `<|constrain|>`.
// Group 2 is not read anywhere in this function.
static const common_regex user_tool_call_regex(
1384+
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?:(?:<\\|constrain\\|>)?([a-zA-Z]+))?\\s*<\\|message\\|>"
1385+
);
1386+
// Built-in tool call header: `browser` or `python`, then anything up to `<|message|>`.
static const common_regex builtin_tool_call_regex("(?:browser|python)[\\s\\S]*<\\|message\\|>");
1387+
1388+
// Save the start of the message so we can roll back when we encounter a tool call and parse_tool_calls == false.
1389+
// Stays 0 for the first message; updated in the loop below for each `<|start|>assistant` match.
size_t message_start_pos = 0;
1390+
1391+
// Similarly, save the channel start so we can roll back to defer reasoning parsing to builder.
1392+
size_t channel_start_pos = 0;
1393+
1394+
// Looks for the next `<|start|>assistant` marker, searching from `from`.
// If one is found, positions the builder at the beginning of the match and
// returns the search result's prelude; otherwise returns the result of
// builder.consume_rest().
// NOTE(review): the third argument (`false`) presumably stops try_find_regex
// from adding the prelude to the content itself - confirm against the
// common_chat_msg_parser declaration.
auto consume_until_next = [&](size_t from = std::string::npos) {
1395+
if (auto res = builder.try_find_regex(start_regex, from, false)) {
1396+
auto begin = res->groups[0].begin;
1397+
builder.move_to(begin);
1398+
return res->prelude;
1399+
}
1400+
return builder.consume_rest();
1401+
};
1402+
1403+
// Consumes `<|message|>` and the message body. Returns true if `<|message|>`
// was consumed, false otherwise. When no `<|end|>` is found after it, the
// remaining input is added as content.
// NOTE(review): when `<|end|>` IS found nothing is added explicitly here;
// presumably try_find_regex adds the preceding text to the content by
// default - confirm.
auto try_consume_message = [&]() {
1404+
if (builder.try_consume_regex(message_regex)) {
1405+
if (!builder.try_find_regex(end_regex)) {
1406+
builder.add_content(builder.consume_rest());
1407+
}
1408+
return true;
1409+
}
1410+
return false;
1411+
};
1412+
1413+
// Handles what follows ` to=`: a user function call, a built-in tool call,
// or an error.
// Throws common_chat_msg_partial_exception("incomplete tool call") when
// add_tool_call() returns false or the parsed arguments are partial.
// Throws common_chat_msg_parse_exception when neither tool-call pattern matches.
auto tool_call = [&]() {
1414+
if (!builder.syntax().parse_tool_calls) {
1415+
// Move back to the start and consume up to the next message
1416+
builder.move_to(message_start_pos);
1417+
// Search from message_start_pos + 1, i.e. just past the position we rewound to.
builder.add_content(consume_until_next(message_start_pos + 1));
1418+
return;
1419+
}
1420+
1421+
if (auto res = builder.try_consume_regex(user_tool_call_regex)) {
1422+
// Group 1 of user_tool_call_regex is the function name.
auto name = builder.str(res->groups[1]);
1423+
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
1424+
// The tool call is registered with an empty id (second argument).
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
1425+
throw common_chat_msg_partial_exception("incomplete tool call");
1426+
}
1427+
}
1428+
} else if (builder.try_consume_regex(builtin_tool_call_regex)) {
1429+
// Built-in tools are recognised but not handled: the result of
// consume_rest() is not used and an error is logged.
builder.consume_rest();
1430+
LOG_ERR("builtin tool calls not implemented\n");
1431+
} else {
1432+
throw common_chat_msg_parse_exception("expected function call, got: " + consume_until_next());
1433+
}
1434+
};
1435+
1436+
// Commentary channel: either a tool call (` to=`) or a plain message.
// Throws common_chat_msg_parse_exception when neither is present.
auto commentary = [&]() {
1437+
if (builder.try_consume_regex(to_regex)) {
1438+
tool_call();
1439+
} else if (!try_consume_message()) {
1440+
throw common_chat_msg_parse_exception("expected: \" to=\" or <|message|>, got: " + consume_until_next());
1441+
}
1442+
};
1443+
1444+
// Analysis channel: either a tool call (` to=`) or reasoning text.
// Throws common_chat_msg_parse_exception when neither ` to=` nor
// `<|message|>` follows the channel header.
auto analysis = [&]() {
1445+
if (builder.try_consume_regex(to_regex)) {
1446+
tool_call(); // built-in tools can be called in the analysis channel
1447+
} else if (builder.try_consume_regex(message_regex)) {
1448+
// Defer reasoning parsing to builder
1449+
// Rewind to the `<|channel|>` match so the full
// "<|channel|>analysis<|message|>" prefix passed below is still ahead.
builder.move_to(channel_start_pos);
1450+
1451+
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE) {
1452+
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
1453+
} else {
1454+
// Reasoning extraction is disabled: the text up to the next
// `<|start|>assistant` marker (or the end of input) is added as content.
builder.add_content(consume_until_next());
1455+
}
1456+
} else {
1457+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1458+
}
1459+
};
1460+
1461+
// Parses one channel header and dispatches on its type.
// Throws common_chat_msg_parse_exception when no channel header is present,
// or when a `final` channel is not followed by `<|message|>`.
auto channel = [&]() {
1462+
if (auto res = builder.try_consume_regex(channel_regex)) {
1463+
// Remember where the channel header began; analysis() rewinds to it.
channel_start_pos = res->groups[0].begin;
1464+
auto type = builder.str(res->groups[1]);
1465+
if (type == "analysis") {
1466+
analysis();
1467+
} else if (type == "commentary") {
1468+
commentary();
1469+
} else if (type == "final") {
1470+
if (!try_consume_message()) {
1471+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1472+
}
1473+
}
1474+
} else {
1475+
throw common_chat_msg_parse_exception("expected: <|channel|>, got: " + consume_until_next());
1476+
}
1477+
};
1478+
1479+
// First message: parsed without requiring a leading `<|start|>assistant`.
// NOTE(review): presumably because that marker is already part of the
// generation prompt - confirm against the chat template.
try {
1480+
channel();
1481+
} catch (const common_chat_msg_parse_exception & e) {
1482+
LOG_DBG("Parse error: %s\n", e.what());
1483+
}
1484+
1485+
// Read in complete messages until done or partial exception raised
1486+
while (auto res = builder.try_consume_regex(start_regex)) {
1487+
// Remember where this message began; tool_call() rewinds to it when
// parse_tool_calls is false.
message_start_pos = res->groups[0].begin;
1488+
try {
1489+
channel();
1490+
} catch (const common_chat_msg_parse_exception & e) {
1491+
LOG_DBG("Parse error: %s\n", e.what());
1492+
}
13511493
}
1494+
1495+
// Whatever input remains is consumed; the returned text is not used.
builder.consume_rest();
13521496
}
13531497

13541498
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {

0 commit comments

Comments
 (0)