Skip to content

Commit 190c483

Browse files
authored
chat : reserve memory in compute_diffs and improve naming (#17729)
1 parent e7c2cf1 commit 190c483

File tree

2 files changed: +22 additions, -14 deletions

common/chat.cpp

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -85,29 +85,36 @@ json common_chat_msg::to_json_oaicompat() const
8585
return message;
8686
}
8787

88-
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
88+
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
8989
std::vector<common_chat_msg_diff> diffs;
90-
if (previous_msg.reasoning_content != new_msg.reasoning_content) {
90+
if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) {
91+
diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3);
92+
} else {
93+
diffs.reserve(3);
94+
}
95+
96+
// TODO: these can become expensive for long messages - how to optimize?
97+
if (msg_prv.reasoning_content != msg_new.reasoning_content) {
9198
auto & diff = diffs.emplace_back();
92-
diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content);
99+
diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content);
93100
}
94-
if (previous_msg.content != new_msg.content) {
101+
if (msg_prv.content != msg_new.content) {
95102
auto & diff = diffs.emplace_back();
96-
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
103+
diff.content_delta = string_diff(msg_prv.content, msg_new.content);
97104
}
98105

99-
if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
106+
if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) {
100107
throw std::runtime_error("Invalid diff: now finding less tool calls!");
101108
}
102109

103-
if (!previous_msg.tool_calls.empty()) {
104-
auto idx = previous_msg.tool_calls.size() - 1;
105-
const auto & pref = previous_msg.tool_calls[idx];
106-
const auto & newf = new_msg.tool_calls[idx];
110+
if (!msg_prv.tool_calls.empty()) {
111+
const auto idx = msg_prv.tool_calls.size() - 1;
112+
const auto & pref = msg_prv.tool_calls[idx];
113+
const auto & newf = msg_new.tool_calls[idx];
107114
if (pref.name != newf.name) {
108115
throw std::runtime_error("Invalid diff: tool call mismatch!");
109116
}
110-
auto args_diff = string_diff(pref.arguments, newf.arguments);
117+
const auto args_diff = string_diff(pref.arguments, newf.arguments);
111118
if (!args_diff.empty() || pref.id != newf.id) {
112119
auto & diff = diffs.emplace_back();
113120
diff.tool_call_index = idx;
@@ -118,11 +125,12 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
118125
diff.tool_call_delta.arguments = args_diff;
119126
}
120127
}
121-
for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
128+
for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) {
122129
auto & diff = diffs.emplace_back();
123130
diff.tool_call_index = idx;
124-
diff.tool_call_delta = new_msg.tool_calls[idx];
131+
diff.tool_call_delta = msg_new.tool_calls[idx];
125132
}
133+
126134
return diffs;
127135
}
128136

common/chat.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ struct common_chat_msg_diff {
7777
size_t tool_call_index = std::string::npos;
7878
common_chat_tool_call tool_call_delta;
7979

80-
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
80+
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
8181

8282
bool operator==(const common_chat_msg_diff & other) const {
8383
return content_delta == other.content_delta

0 commit comments

Comments
 (0)