Skip to content

Commit d65e556

Browse files
committed
simplify logic by combining regex
1 parent e6a4578 commit d65e556

File tree

1 file changed

+33
-46
lines changed

1 file changed

+33
-46
lines changed

common/chat.cpp

Lines changed: 33 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1351,13 +1351,11 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
13511351
return data;
13521352
}
13531353
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
1354-
static const common_regex assistant_regex("assistant");
13551354
static const common_regex message_regex("<\\|message\\|>");
1356-
static const common_regex channel_regex("<\\|channel\\|>");
1357-
static const common_regex start_regex("<\\|start\\|>");
1355+
static const common_regex channel_regex("<\\|channel\\|>(final|analysis|commentary)");
1356+
static const common_regex start_regex("<\\|start\\|>assistant");
13581357
static const common_regex end_regex("<\\|end\\|>");
13591358
static const common_regex to_regex(" to=");
1360-
static const common_regex channel_type_regexp("(final|analysis|commentary)");
13611359
static const common_regex user_tool_call_regex(
13621360
"functions\\.([a-zA-Z_][a-zA-Z0-9_]*)\\s*(?:(?:<\\|constrain\\|>)?([a-zA-Z]+))?\\s*<\\|message\\|>"
13631361
);
@@ -1378,7 +1376,24 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13781376
return builder.consume_rest();
13791377
};
13801378

1379+
auto try_consume_message = [&]() {
1380+
if (builder.try_consume_regex(message_regex)) {
1381+
if (!builder.try_find_regex(end_regex)) {
1382+
builder.add_content(builder.consume_rest());
1383+
}
1384+
return true;
1385+
}
1386+
return false;
1387+
};
1388+
13811389
auto tool_call = [&]() {
1390+
if (!builder.syntax().parse_tool_calls) {
1391+
// Move back to the start and consume up to the next message
1392+
builder.move_to(message_start_pos);
1393+
builder.add_content(consume_until_next(message_start_pos + 1));
1394+
return;
1395+
}
1396+
13821397
if (auto res = builder.try_consume_regex(user_tool_call_regex)) {
13831398
auto name = builder.str(res->groups[1]);
13841399
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
@@ -1396,30 +1411,12 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
13961411

13971412
auto commentary = [&]() {
13981413
if (builder.try_consume_regex(to_regex)) {
1399-
if (builder.syntax().parse_tool_calls) {
1400-
tool_call();
1401-
} else {
1402-
// Move back to the start and consume up to the next message
1403-
builder.move_to(message_start_pos);
1404-
builder.add_content(consume_until_next(message_start_pos + 1));
1405-
}
1406-
} else if (builder.try_consume_regex(message_regex)) {
1407-
if (!builder.try_find_regex(end_regex)) {
1408-
builder.add_content(builder.consume_rest());
1409-
}
1410-
} else {
1414+
tool_call();
1415+
} else if (!try_consume_message()) {
14111416
throw common_chat_msg_parse_exception("expected: \" to=\" or <|message|>, got: " + consume_until_next());
14121417
}
14131418
};
14141419

1415-
auto final = [&]() {
1416-
if (builder.try_consume_regex(message_regex)) {
1417-
builder.add_content(builder.consume_rest());
1418-
} else {
1419-
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
1420-
}
1421-
};
1422-
14231420
auto analysis = [&]() {
14241421
if (builder.try_consume_regex(to_regex)) {
14251422
tool_call(); // built-in tools can be called in the analysis channel
@@ -1438,44 +1435,34 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
14381435
};
14391436

14401437
auto channel = [&]() {
1441-
if (auto channel = builder.try_consume_regex(channel_regex)) {
1442-
channel_start_pos = channel->groups[0].begin;
1443-
if (auto res = builder.try_consume_regex(channel_type_regexp)) {
1444-
auto type = builder.str(res->groups[0]);
1445-
if (type == "analysis") {
1446-
analysis();
1447-
} else if (type == "final") {
1448-
final();
1449-
} else if (type == "commentary") {
1450-
commentary();
1438+
if (auto res = builder.try_consume_regex(channel_regex)) {
1439+
channel_start_pos = res->groups[0].begin;
1440+
auto type = builder.str(res->groups[1]);
1441+
if (type == "analysis") {
1442+
analysis();
1443+
} else if (type == "commentary") {
1444+
commentary();
1445+
} else if (type == "final") {
1446+
if (!try_consume_message()) {
1447+
throw common_chat_msg_parse_exception("expected: <|message|>, got: " + consume_until_next());
14511448
}
1452-
} else {
1453-
throw common_chat_msg_parse_exception("expected one of: analysis, final, commentary, got: " + consume_until_next());
14541449
}
14551450
} else {
14561451
throw common_chat_msg_parse_exception("expected: <|channel|>, got: " + consume_until_next());
14571452
}
14581453
};
14591454

1460-
auto start = [&]() {
1461-
if (builder.try_consume_regex(assistant_regex)) {
1462-
channel();
1463-
} else {
1464-
throw common_chat_msg_parse_exception("expected: <|assistant|>, got: " + consume_until_next());
1465-
}
1466-
};
1467-
14681455
try {
14691456
channel();
14701457
} catch (const common_chat_msg_parse_exception & e) {
14711458
LOG_DBG("Parse error: %s\n", e.what());
14721459
}
14731460

14741461
// Read in complete messages until done or partial exception raised
1475-
while (auto res = builder.try_find_literal("<|start|>")) {
1462+
while (auto res = builder.try_consume_regex(start_regex)) {
14761463
message_start_pos = res->groups[0].begin;
14771464
try {
1478-
start();
1465+
channel();
14791466
} catch (const common_chat_msg_parse_exception & e) {
14801467
LOG_DBG("Parse error: %s\n", e.what());
14811468
}

0 commit comments

Comments
 (0)