Skip to content

Commit 83c46ba

Browse files
committed
parser updates
1 parent 633455e commit 83c46ba

File tree

4 files changed

+78
-11
lines changed

4 files changed

+78
-11
lines changed

common/chat.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,37 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13181318
data.prompt = prompt;
13191319
data.format = COMMON_CHAT_FORMAT_GPT_OSS;
13201320

1321-
// TODO: support tool calls in GPT-OSS?
1321+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
1322+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
1323+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
1324+
std::vector<std::string> tool_rules;
1325+
foreach_function(inputs.tools, [&](const json & tool) {
1326+
const auto & function = tool.at("function");
1327+
std::string name = function.at("name");
1328+
auto parameters = function.at("parameters");
1329+
builder.resolve_refs(parameters);
1330+
1331+
tool_rules.push_back(builder.add_rule(name + "-call",
1332+
"\"" + name + "\"" + " space? \"<|constrain|>json<|message|>\" " + builder.add_schema(name + "-args", parameters)
1333+
));
1334+
});
1335+
1336+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
1337+
builder.add_rule("root", "\"<|channel|>commentary\" space \"to=functions.\" " + tool_call);
1338+
1339+
data.grammar_triggers.push_back({
1340+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
1341+
"<\\|channel\\|>commentary\\s+to=functions\\."
1342+
});
1343+
1344+
data.preserved_tokens = {
1345+
"<|channel|>",
1346+
"<|constrain|>",
1347+
"<|message|>",
1348+
"<|start|>",
1349+
};
1350+
});
1351+
}
13221352

13231353
return data;
13241354
}

common/parsers/harmony.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
#include "harmony.h"
22
#include "regex-partial.h"
3+
#include "log.h"
34

45
harmony_msg_parser::harmony_msg_parser(common_chat_msg_parser & builder)
56
: builder(builder) {}
67

78
void harmony_msg_parser::parse() {
8-
// TODO @ngxson : this won't work with --special enabled, we should fix that
9-
//builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
10-
//builder.add_content(builder.consume_rest());
11-
channel();
9+
try {
10+
channel();
11+
} catch (const harmony_parse_error & e) {
12+
LOG_ERR("Parse error: %s", e.what());
13+
}
14+
15+
while (builder.try_find_literal("<|start|>")) {
16+
try {
17+
start();
18+
} catch(const harmony_parse_error & e) {
19+
LOG_ERR("Parse error: %s, skipping to next valid token", e.what());
20+
}
21+
}
22+
1223
builder.add_content(builder.consume_rest());
1324
}
1425

@@ -33,9 +44,6 @@ void harmony_msg_parser::analysis() {
3344
static const common_regex end("<\\|end\\|>");
3445
if (auto res = builder.try_find_regex(end, std::string::npos, false)) {
3546
builder.add_reasoning_content(res->prelude);
36-
if (builder.try_consume_literal("<|start|>")) {
37-
start();
38-
}
3947
} else {
4048
builder.add_reasoning_content(builder.consume_rest());
4149
}
@@ -59,15 +67,21 @@ void harmony_msg_parser::commentary() {
5967
if (builder.try_consume_regex(to)) {
6068
user_function();
6169
}
70+
71+
if (builder.try_consume_literal("<|message|>")) {
72+
static const common_regex end("<\\|end\\|>");
73+
if (!builder.try_find_regex(end)) {
74+
builder.add_content(builder.consume_rest());
75+
}
76+
}
6277
}
6378

6479
void harmony_msg_parser::user_function() {
6580
static const common_regex tool_call_regex(
66-
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s*(<\\|constrain\\|>)?([a-z]+)<\\|message\\|>"
81+
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s?(?:<\\|constrain\\|>([a-zA-Z]+))?<\\|message\\|>"
6782
);
6883
if (auto res = builder.try_consume_regex(tool_call_regex)) {
6984
auto name = builder.str(res->groups[1]);
70-
//auto constrain_type = builder.str(res->groups[3]);
7185
auto args = builder.consume_rest();
7286

7387
builder.add_tool_call(name, "", args);

common/parsers/harmony.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
#include <string>
77
#include <vector>
88

9+
class harmony_parse_error: public std::runtime_error {
10+
public:
11+
harmony_parse_error(const std::string & message) : std::runtime_error(message) {}
12+
};
13+
914
class harmony_msg_parser {
1015
common_chat_msg_parser & builder;
1116

tests/test-chat.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1725,7 +1725,25 @@ static void test_template_output_parsers() {
17251725
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
17261726
}));
17271727

1728-
1728+
assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"),
1729+
common_chat_parse(
1730+
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1731+
"<|start|>assistant<|channel|>commentary<|message|>Hello, world!\nWhat's up?",
1732+
/* is_partial= */ true,
1733+
{
1734+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1735+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1736+
}));
1737+
assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}"),
1738+
common_chat_parse(
1739+
"<|channel|>analysis<|message|>I'm\nthinking<|end|>"
1740+
"<|start|>assistant<|channel|>commentary<|message|>Hello, world!\nWhat's up?<|end|>"
1741+
"<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}",
1742+
/* is_partial= */ true,
1743+
{
1744+
/* .format = */ COMMON_CHAT_FORMAT_GPT_OSS,
1745+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
1746+
}));
17291747
}
17301748
}
17311749

0 commit comments

Comments
 (0)