Skip to content

Commit 7856949

Browse files
author
ochafik
committed
reinstate tool call id logic, keep track of previously generated ids
1 parent f0ea330 commit 7856949

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

common/chat.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
9292
auto & diff = diffs.emplace_back();
9393
diff.tool_call_index = idx;
9494
diff.tool_call_delta.name = newf.name;
95-
diff.tool_call_delta.id = newf.id;
95+
if (pref.id != newf.id) {
96+
diff.tool_call_delta.id = newf.id;
97+
}
9698
diff.tool_call_delta.arguments = args_diff;
9799
}
98100
}

common/chat.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@ struct common_chat_msg {
4242
bool empty() const {
4343
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
4444
}
45+
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
46+
for (auto i = 0u; i < tool_calls.size(); i++) {
47+
if (ids_cache.size() <= i) {
48+
auto id = tool_calls[i].id;
49+
if (id.empty()) {
50+
id = gen_tool_call_id();
51+
}
52+
ids_cache.push_back(id);
53+
}
54+
tool_calls[i].id = ids_cache[i];
55+
}
56+
}
4557
bool operator==(const common_chat_msg & other) const {
4658
return role == other.role
4759
&& content == other.content

examples/server/server.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,7 @@ struct server_slot {
12821282
llama_token sampled;
12831283

12841284
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
1285+
std::vector<std::string> generated_tool_call_ids;
12851286

12861287
// stats
12871288
size_t n_sent_text = 0; // number of sent text character
@@ -1313,6 +1314,7 @@ struct server_slot {
13131314
generated_token_probs.clear();
13141315
generated_msg = {};
13151316
json_schema = json();
1317+
generated_tool_call_ids.clear();
13161318
}
13171319

13181320
bool is_non_causal() const {
@@ -2356,14 +2358,12 @@ struct server_context {
23562358
/* is_partial= */ true,
23572359
slot.params.oaicompat_chat_syntax);
23582360
if (!new_msg.empty()) {
2361+
new_msg.ensure_tool_call_ids_set(slot.generated_tool_call_ids, gen_tool_call_id);
23592362
slot.generated_msg = new_msg;
23602363
}
23612364
res->oaicompat_previous_msg = previous_msg;
23622365
res->oaicompat_new_msg = new_msg.empty() ? previous_msg : new_msg;
23632366

2364-
// res->previous_content = slot.generated_text.substr(0, slot.generated_text.size() - tkn.text_to_send.size());
2365-
// res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
2366-
23672367
// populate res.probs_output
23682368
if (slot.params.sampling.n_probs > 0) {
23692369
res->prob_output = tkn; // copy the token probs
@@ -2409,6 +2409,7 @@ struct server_context {
24092409
res->content,
24102410
/* is_partial= */ slot.stop == STOP_TYPE_LIMIT,
24112411
slot.params.oaicompat_chat_syntax);
2412+
res->oaicompat_msg.ensure_tool_call_ids_set(slot.generated_tool_call_ids, gen_tool_call_id);
24122413
res->oaicompat_chat_syntax = slot.params.oaicompat_chat_syntax;
24132414

24142415
// populate res.probs_output

0 commit comments

Comments
 (0)