Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 140 additions & 7 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msg
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
jmsg["thinking"] = msg.reasoning_content; // gpt-oss
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
Expand Down Expand Up @@ -1338,17 +1339,149 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_GPT_OSS;

// TODO: support tool calls in GPT-OSS?
// These special tokens are required to parse properly, so we include them
// even if parse_tool_calls is false.
data.preserved_tokens = {
"<|channel|>",
"<|constrain|>",
"<|message|>",
"<|start|>",
"<|end|>",
};

if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
// tool calls can appear in commentary or analysis channels
auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )");

std::vector<std::string> tool_rules_recipient_in_role;
std::vector<std::string> tool_rules_recipient_in_channel;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);

tool_rules_recipient_in_role.push_back(
builder.add_rule(name + "-call",
"\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " +
builder.add_schema(name + "-args", parameters)
)
);

tool_rules_recipient_in_channel.push_back(
builder.add_rule(name + "-call",
"\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " +
builder.add_schema(name + "-args", parameters)
)
);
});

auto recipient_in_role = builder.add_rule("recipient_in_role",
"\"<|start|>assistant\"? \" to=functions.\" ( " +
string_join(tool_rules_recipient_in_role, " | ") + " )"
);

auto recipient_in_channel = builder.add_rule("recipient_in_channel",
channel + " \" to=functions.\" ( " +
string_join(tool_rules_recipient_in_channel, " | ") + " )"
);

builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel);

// Trigger on tool calls that appear in the commentary channel
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
"<\\|channel\\|>(commentary|analysis) to"
});

// Trigger tool calls that appear in the role section, either at the
// start or in the middle.
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
"^ to"
});

data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
"<\\|start\\|>assistant to"
});
});
}

return data;
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
// TODO @ngxson : this won't work with --special enabled, we should fix that
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");

static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex end_regex("<\\|end\\|>");
static const common_regex basic_analysis_regex("<\\|channel\\|>analysis<\\|message\\|>");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?<\\|message\\|>");
static const common_regex commentary_preamble_regex("<\\|channel\\|>commentary<\\|message\\|>");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?<\\|message\\|>");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?<\\|message\\|>");

auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_regex(end_regex, std::string::npos, false)) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};

auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else {
std::string args_as_string;
if (args->value.is_object()) {
args_as_string = args->value.dump();
} else {
args_as_string = args->value;
}

// simulate tool call in content
builder.add_content("<tool_call>");
builder.add_content("{\"name\": " + json(name).dump() + ", \"arguments\": ");
builder.add_content(args_as_string);
if (!args->is_partial) {
builder.add_content("}");
builder.add_content("</tool_call>");
} else {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
}
};

do {
if (auto res = builder.try_consume_regex(basic_analysis_regex)) {
builder.move_to(res->groups[0].begin);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
} else if (builder.try_consume_regex(final_regex)) {
builder.add_content(consume_end());
break;
} else if (builder.try_consume_regex(commentary_preamble_regex)) {
builder.add_content(consume_end());
} else if (auto res = builder.try_consume_regex(tool_call1_regex)) {
std::string name = builder.str(res->groups[1]);
handle_tool_call(name);
} else if (auto res = builder.try_consume_regex(tool_call2_regex)) {
std::string name = builder.str(res->groups[2]);
handle_tool_call(name);
}
} while (builder.try_find_regex(start_regex, std::string::npos, false));

builder.consume_rest();
}

static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
Expand Down
Loading