Skip to content

Commit ab55d76

Browse files
Fix type conversion
In a rare case, parsing was failing on the tool argument type.
1 parent 623f3dd commit ab55d76

File tree

1 file changed

+46
-34
lines changed

1 file changed

+46
-34
lines changed

common/chat.cpp

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,14 +1384,14 @@ static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
13841384
auto handle_tool_call_end = [&] (common_chat_msg_parser & builder, auto end_pos) {
13851385
builder.move_to(end_pos);
13861386
builder.consume_literal("</tool_call>");
1387-
1387+
13881388
size_t obs_pos = builder.input().find("<|observation|>", builder.pos());
13891389
if (obs_pos != std::string::npos) {
13901390
if (obs_pos > builder.pos()) {
13911391
std::string content = builder.input().substr(builder.pos(), obs_pos - builder.pos());
13921392
builder.add_content(content);
13931393
}
1394-
1394+
13951395
builder.move_to(obs_pos);
13961396
builder.consume_literal("<|observation|>");
13971397
} else {
@@ -1401,91 +1401,105 @@ static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
14011401
};
14021402

14031403
builder.consume_spaces();
1404-
14051404
builder.try_parse_reasoning("<think>", "</think>");
1406-
1405+
14071406
size_t curr_pos = builder.pos();
14081407
while (builder.input().find("<tool_call>", builder.pos()) != std::string::npos) {
14091408
size_t tool_call_start = builder.input().find("<tool_call>", builder.pos());
14101409
if (tool_call_start > builder.pos()) {
14111410
std::string content = builder.input().substr(builder.pos(), tool_call_start - builder.pos());
14121411
builder.add_content(content);
14131412
}
1414-
1413+
14151414
size_t tool_call_end = builder.input().find("</tool_call>", tool_call_start);
14161415
if (tool_call_end == std::string::npos) return;
14171416

14181417
builder.move_to(tool_call_start);
14191418
builder.consume_literal("<tool_call>");
14201419
builder.consume_spaces();
14211420

1422-
size_t arg_key_start = builder.input().find("<arg_key>", tool_call_start);
1421+
size_t arg_key_start = builder.input().find("<arg_key>", builder.pos());
14231422
if (arg_key_start == std::string::npos || arg_key_start > tool_call_end) {
14241423
std::string function_content = builder.input().substr(builder.pos(), tool_call_end - builder.pos());
14251424
std::string function_name = string_strip(function_content);
1426-
1425+
14271426
if (!builder.add_tool_call(function_name, "", "{}")) {
14281427
LOG_INF("%s: failed to add tool call\n", __func__);
14291428
}
1430-
14311429
handle_tool_call_end(builder, tool_call_end);
1432-
14331430
} else {
14341431
std::string function_content = builder.input().substr(builder.pos(), arg_key_start - builder.pos());
14351432
std::string function_name = string_strip(function_content);
1436-
1433+
14371434
json args_json = json::object();
14381435
builder.move_to(arg_key_start);
1439-
1440-
while (builder.pos() < tool_call_end && builder.input().substr(builder.pos()).find("<arg_key>") == 0) {
1436+
1437+
while (builder.pos() < tool_call_end && builder.input().substr(builder.pos()).rfind("<arg_key>", 0) == 0) {
14411438
if (!builder.try_consume_literal("<arg_key>")) break;
1442-
1439+
14431440
auto key_close = builder.try_find_literal("</arg_key>");
14441441
if (!key_close || key_close->groups[0].end > tool_call_end) {
1445-
throw common_chat_msg_partial_exception("incomplete tool call");
1446-
return;
1442+
throw common_chat_msg_partial_exception("incomplete tool call (arg_key)");
14471443
}
1448-
14491444
std::string key = string_strip(key_close->prelude);
1450-
1445+
14511446
builder.consume_spaces();
1452-
14531447
if (!builder.try_consume_literal("<arg_value>")) {
1454-
throw common_chat_msg_partial_exception("incomplete tool call");
1455-
return;
1448+
throw common_chat_msg_partial_exception("incomplete tool call (arg_value)");
14561449
}
1457-
1450+
14581451
auto value_close = builder.try_find_literal("</arg_value>");
14591452
if (!value_close || value_close->groups[0].end > tool_call_end) {
1460-
throw common_chat_msg_partial_exception("incomplete tool call");
1461-
return;
1453+
throw common_chat_msg_partial_exception("incomplete tool call (arg_value content)");
14621454
}
1463-
14641455
std::string value = string_strip(value_close->prelude);
14651456

1466-
// Schema-aware type conversion
14671457
std::string expected_type = get_expected_type(function_name, key);
14681458
json parsed_value;
14691459

1470-
if (expected_type == "array" || expected_type == "object") {
1460+
if (expected_type == "integer" || expected_type == "number") {
1461+
try {
1462+
if (value.find('.') != std::string::npos) {
1463+
parsed_value = std::stod(value);
1464+
} else {
1465+
parsed_value = std::stoll(value);
1466+
}
1467+
} catch (const std::exception&) {
1468+
LOG_WRN("%s: Failed to parse '%s' as a number for key '%s', falling back to string.\n", __func__, value.c_str(), key.c_str());
1469+
parsed_value = value;
1470+
}
1471+
} else if (expected_type == "boolean") {
1472+
std::string lower_val = value;
1473+
std::transform(lower_val.begin(), lower_val.end(), lower_val.begin(),
1474+
[](unsigned char c){ return std::tolower(c); });
1475+
if (lower_val == "true" || lower_val == "1") {
1476+
parsed_value = true;
1477+
} else if (lower_val == "false" || lower_val == "0") {
1478+
parsed_value = false;
1479+
} else {
1480+
LOG_WRN("%s: Ambiguous boolean value '%s' for key '%s', falling back to string.\n", __func__, value.c_str(), key.c_str());
1481+
parsed_value = value;
1482+
}
1483+
} else if (expected_type == "array" || expected_type == "object") {
14711484
try {
14721485
parsed_value = json::parse(value);
1473-
} catch (...) {
1486+
} catch (const json::parse_error&) {
1487+
LOG_WRN("%s: Failed to parse '%s' as JSON for key '%s', falling back to raw string.\n", __func__, value.c_str(), key.c_str());
14741488
parsed_value = value;
14751489
}
14761490
} else {
1477-
// For all other types, store as string and let the unpacking logic handle it
1491+
// Default case is "string".
14781492
parsed_value = value;
14791493
}
1480-
1494+
14811495
args_json[key] = parsed_value;
14821496
builder.consume_spaces();
14831497
}
1484-
1498+
1499+
// This is a special case to handle when the model outputs a single JSON object as a string
14851500
if (args_json.size() == 1) {
14861501
const auto key = args_json.begin().key();
14871502
auto& value = args_json.begin().value();
1488-
14891503
if (value.is_string()) {
14901504
try {
14911505
json unpacked_json = json::parse(value.get<std::string>());
@@ -1503,12 +1517,10 @@ static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
15031517
} else {
15041518
LOG_INF("%s: successfully added tool call with arguments\n", __func__);
15051519
}
1506-
15071520
handle_tool_call_end(builder, tool_call_end);
15081521
}
1509-
1522+
15101523
if (curr_pos == builder.pos()) {
1511-
// No progress made, avoid infinite loop
15121524
LOG_INF("%s: no progress in parsing, stopping to avoid infinite loop\n", __func__);
15131525
break;
15141526
}

0 commit comments

Comments
 (0)