Skip to content

Commit d52579a

Browse files
author
Olivier Chafik
committed
prefer json::at to operator[] in chat.cpp
1 parent 37a4bb2 commit d52579a

File tree

1 file changed

+50
-50
lines changed

1 file changed

+50
-50
lines changed

common/chat.cpp

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
142142
result.role = "assistant";
143143
const auto process_tool_calls = [&](const json & tool_calls) {
144144
for (const auto & tool_call : tool_calls) {
145-
const auto & arguments = tool_call["arguments"];
145+
const auto & arguments = tool_call.at("arguments");
146146
result.tool_calls.push_back({
147-
tool_call["name"],
147+
tool_call.at("name"),
148148
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
149-
tool_call.contains("id") ? tool_call["id"] : "",
149+
tool_call.contains("id") ? tool_call.at("id") : "",
150150
});
151151
}
152152
};
@@ -163,7 +163,7 @@ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& in
163163

164164
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
165165
for (const auto & tool : tools) {
166-
if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) {
166+
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
167167
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
168168
continue;
169169
}
@@ -198,27 +198,27 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
198198

199199
auto tool_call_schemas = json::array();
200200
foreach_function(inputs.tools, [&](const json & tool) {
201-
const auto & function = tool["function"];
201+
const auto & function = tool.at("function");
202202
auto tool_schema = json {
203203
{"type", "object"},
204204
{"properties", {
205205
{"name", {
206206
{"type", "string"},
207-
{"const", function["name"]},
207+
{"const", function.at("name")},
208208
}},
209-
{"arguments", function["parameters"]},
209+
{"arguments", function.at("parameters")},
210210
}},
211211
{"required", json::array({"name", "arguments"})},
212212
};
213213
if (function.contains("description")) {
214-
tool_schema["description"] = function["description"];
214+
tool_schema["description"] = function.at("description");
215215
}
216216
if (inputs.parallel_tool_calls) {
217-
tool_schema["properties"]["id"] = {
217+
tool_schema.at("properties")["id"] = {
218218
{"type", "string"},
219219
{"minLength", 4},
220220
};
221-
tool_schema["required"].push_back("id");
221+
tool_schema.at("required").push_back("id");
222222
}
223223
tool_call_schemas.emplace_back(tool_schema);
224224
});
@@ -283,21 +283,21 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
283283
common_chat_msg result;
284284
result.role = "assistant";
285285
if (data.contains("tool_calls")) {
286-
for (const auto & tool_call : data["tool_calls"]) {
286+
for (const auto & tool_call : data.at("tool_calls")) {
287287
result.tool_calls.push_back({
288-
tool_call["name"],
289-
tool_call["arguments"].dump(),
290-
tool_call.contains("id") ? tool_call["id"] : "",
288+
tool_call.at("name"),
289+
tool_call.at("arguments").dump(),
290+
tool_call.contains("id") ? tool_call.at("id") : "",
291291
});
292292
}
293293
} else if (data.contains("tool_call")) {
294294
result.tool_calls.push_back({
295-
data["tool_call"]["name"],
296-
data["tool_call"]["arguments"].dump(),
295+
data.at("tool_call").at("name"),
296+
data.at("tool_call").at("arguments").dump(),
297297
/* id= */ "",
298298
});
299299
} else if (data.contains("response")) {
300-
const auto & response = data["response"];
300+
const auto & response = data.at("response");
301301
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
302302
}
303303
return result;
@@ -309,17 +309,17 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
309309
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
310310
auto schemas = json::array();
311311
foreach_function(inputs.tools, [&](const json & tool) {
312-
const auto & function = tool["function"];
312+
const auto & function = tool.at("function");
313313
schemas.push_back({
314314
{"type", "object"},
315315
{"properties", {
316316
// Important note: the model is probably trained to take a JSON stringified arguments value.
317317
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
318318
{"name", {
319319
{"type", "string"},
320-
{"const", function["name"]},
320+
{"const", function.at("name")},
321321
}},
322-
{"arguments", function["parameters"]},
322+
{"arguments", function.at("parameters")},
323323
{"id", {
324324
{"type", "string"},
325325
// Nemo's template expects a 9-character alphanumeric ID.
@@ -354,7 +354,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
354354
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
355355
auto schemas = json::array();
356356
foreach_function(inputs.tools, [&](const json & tool) {
357-
const auto & function = tool["function"];
357+
const auto & function = tool.at("function");
358358
schemas.push_back({
359359
{"type", "object"},
360360
{"properties", {
@@ -365,9 +365,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
365365
}},
366366
{"tool_name", {
367367
{"type", "string"},
368-
{"const", function["name"]},
368+
{"const", function.at("name")},
369369
}},
370-
{"parameters", function["parameters"]},
370+
{"parameters", function.at("parameters")},
371371
}},
372372
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
373373
});
@@ -392,11 +392,11 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_
392392
};
393393
auto adjusted_messages = json::array();
394394
for (const auto & msg : inputs.messages) {
395-
auto has_reasoning_content = msg.contains("reasoning_content") && msg["reasoning_content"].is_string();
396-
auto has_tool_calls = msg.contains("tool_calls") && msg["tool_calls"].is_array();
395+
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
396+
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
397397
if (has_reasoning_content && has_tool_calls) {
398398
auto adjusted_message = msg;
399-
adjusted_message["tool_plan"] = msg["reasoning_content"];
399+
adjusted_message["tool_plan"] = msg.at("reasoning_content");
400400
adjusted_message.erase("reasoning_content");
401401
adjusted_messages.push_back(adjusted_message);
402402
} else {
@@ -433,9 +433,9 @@ static common_chat_msg common_chat_parse_command_r7b(const std::string & input,
433433
auto actions = json::parse(actions_str);
434434
for (const auto & action : actions) {
435435
result.tool_calls.push_back({
436-
/* .name = */ action["tool_name"],
437-
/* .arguments = */ action["parameters"].dump(),
438-
/* .id = */ action["tool_call_id"],
436+
/* .name = */ action.at("tool_name"),
437+
/* .arguments = */ action.at("parameters").dump(),
438+
/* .id = */ action.at("tool_call_id"),
439439
});
440440
}
441441
} else if (std::regex_match(rest, match, response_regex)) {
@@ -448,7 +448,7 @@ static common_chat_msg common_chat_parse_command_r7b(const std::string & input,
448448
}
449449

450450
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
451-
if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
451+
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
452452
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
453453
}
454454
const auto & parameters_properties = parameters.at("properties");
@@ -502,9 +502,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
502502
};
503503

504504
foreach_function(inputs.tools, [&](const json & tool) {
505-
const auto & function = tool["function"];
506-
std::string name = function["name"];
507-
auto parameters = function["parameters"];
505+
const auto & function = tool.at("function");
506+
std::string name = function.at("name");
507+
auto parameters = function.at("parameters");
508508
builder.resolve_refs(parameters);
509509

510510
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
@@ -585,9 +585,9 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
585585
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
586586
std::vector<std::string> tool_rules;
587587
foreach_function(inputs.tools, [&](const json & tool) {
588-
const auto & function = tool["function"];
589-
std::string name = function["name"];
590-
auto parameters = function["parameters"];
588+
const auto & function = tool.at("function");
589+
std::string name = function.at("name");
590+
auto parameters = function.at("parameters");
591591
auto args_rule = builder.add_schema(name + "-args", parameters);
592592
tool_rules.push_back(builder.add_rule(name + "-call",
593593
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
@@ -678,15 +678,15 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
678678
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
679679
auto schemas = json::array();
680680
foreach_function(inputs.tools, [&](const json & tool) {
681-
const auto & function = tool["function"];
681+
const auto & function = tool.at("function");
682682
schemas.push_back({
683683
{"type", "object"},
684684
{"properties", {
685685
{"name", {
686686
{"type", "string"},
687-
{"const", function["name"]},
687+
{"const", function.at("name")},
688688
}},
689-
{"arguments", function["parameters"]},
689+
{"arguments", function.at("parameters")},
690690
}},
691691
{"required", json::array({"name", "arguments", "id"})},
692692
});
@@ -724,9 +724,9 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
724724
std::vector<std::string> first_tool_rules;
725725
std::vector<std::string> subsequent_tool_rules;
726726
foreach_function(inputs.tools, [&](const json & tool) {
727-
const auto & function = tool["function"];
728-
std::string name = function["name"];
729-
auto parameters = function["parameters"];
727+
const auto & function = tool.at("function");
728+
std::string name = function.at("name");
729+
auto parameters = function.at("parameters");
730730
auto args_rule = builder.add_schema(name + "-args", parameters);
731731
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
732732
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -806,9 +806,9 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
806806
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
807807
std::vector<std::string> tool_rules;
808808
foreach_function(inputs.tools, [&](const json & tool) {
809-
const auto & function = tool["function"];
810-
const auto & parameters = function["parameters"];
811-
std::string name = function["name"];
809+
const auto & function = tool.at("function");
810+
const auto & parameters = function.at("parameters");
811+
std::string name = function.at("name");
812812
if (name == "python" || name == "ipython") {
813813
if (!parameters.contains("type")) {
814814
throw std::runtime_error("Missing type in python tool");
@@ -879,9 +879,9 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
879879
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
880880
std::vector<std::string> tool_rules;
881881
foreach_function(inputs.tools, [&](const json & tool) {
882-
const auto & function = tool["function"];
883-
std::string name = function["name"];
884-
auto parameters = function["parameters"];
882+
const auto & function = tool.at("function");
883+
std::string name = function.at("name");
884+
auto parameters = function.at("parameters");
885885
builder.resolve_refs(parameters);
886886
tool_rules.push_back(builder.add_schema(name + "-call", {
887887
{"type", "object"},
@@ -929,9 +929,9 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
929929
if (!parse_json(it, end, call)) {
930930
throw std::runtime_error("Failed to parse json tool call");
931931
}
932-
const auto & arguments = call["arguments"];
932+
const auto & arguments = call.at("arguments");
933933
result.tool_calls.push_back({
934-
call["name"],
934+
call.at("name"),
935935
arguments.dump(),
936936
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
937937
/* id= */ "",

0 commit comments

Comments
 (0)