Skip to content

Commit ce2f593

Browse files
author
ochafik
committed
add common_chat_msg_diff
1 parent a95fe78 commit ce2f593

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

common/chat.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,44 @@
55
#include "minja/minja.hpp"
66

77
#include <optional>
8+
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) {
9+
std::vector<common_chat_msg_diff> diffs;
10+
// if (previous_msg.reasoning_content != current.reasoning_content) {
11+
// auto & diff = diffs.emplace_back();
12+
// diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, current.reasoning_content);
13+
// }
14+
if (previous_msg.content != new_msg.content) {
15+
auto & diff = diffs.emplace_back();
16+
diff.content_delta = string_diff(previous_msg.content, new_msg.content);
17+
}
18+
19+
if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) {
20+
throw std::runtime_error("Invalid diff: now finding less tool calls!");
21+
}
22+
23+
if (!previous_msg.tool_calls.empty()) {
24+
auto idx = previous_msg.tool_calls.size() - 1;
25+
const auto & pref = previous_msg.tool_calls[idx];
26+
const auto & newf = new_msg.tool_calls[idx];
27+
if (pref.name != newf.name) {
28+
throw std::runtime_error("Invalid diff: tool call mismatch!");
29+
}
30+
auto args_diff = string_diff(pref.arguments, newf.arguments);
31+
if (!args_diff.empty() || pref.id != newf.id) {
32+
auto & diff = diffs.emplace_back();
33+
diff.tool_call_index = idx;
34+
diff.tool_call_delta.name = newf.name;
35+
diff.tool_call_delta.id = newf.id;
36+
diff.tool_call_delta.arguments = args_diff;
37+
}
38+
}
39+
for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) {
40+
auto & diff = diffs.emplace_back();
41+
diff.tool_call_index = idx;
42+
diff.tool_call_delta = new_msg.tool_calls[idx];
43+
}
44+
return diffs;
45+
}
846

947
typedef minja::chat_template common_chat_template;
1048

common/chat.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ struct common_chat_tool_call {
1212
std::string name;
1313
std::string arguments;
1414
std::string id;
15+
16+
// Two tool calls compare equal when name, arguments and id are all equal.
bool operator==(const common_chat_tool_call & other) const {
    if (name != other.name) {
        return false;
    }
    if (arguments != other.arguments) {
        return false;
    }
    return id == other.id;
}
1519
};
1620

1721
struct common_chat_msg_content_part {
@@ -27,6 +31,10 @@ struct common_chat_msg {
2731
std::string reasoning_content;
2832
std::string tool_name;
2933
std::string tool_call_id;
34+
35+
bool empty() const {
36+
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
37+
}
3038
};
3139

3240
struct common_chat_tool {
@@ -133,3 +141,18 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
133141
// T can be std::string containing JSON or nlohmann::ordered_json
134142
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
135143
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
144+
145+
struct common_chat_msg_diff {
146+
// std::string reasoning_content_delta;
147+
std::string content_delta;
148+
size_t tool_call_index = std::string::npos;
149+
common_chat_tool_call tool_call_delta;
150+
151+
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
152+
153+
bool operator==(const common_chat_msg_diff & other) const {
154+
return content_delta == other.content_delta
155+
&& tool_call_index == other.tool_call_index
156+
&& tool_call_delta == other.tool_call_delta;
157+
}
158+
};

tests/test-chat.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,27 @@
1616

1717
using json = nlohmann::ordered_json;
1818

19+
static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
20+
// os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n';
21+
os << "{ content_delta: " << diff.content_delta << "; ";
22+
if (diff.tool_call_index != std::string::npos) {
23+
os << "tool_call_index: " << diff.tool_call_index << "; ";
24+
os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";
25+
os << "tool_call_delta.id: " << diff.tool_call_delta.id << "; ";
26+
os << "tool_call_delta.arguments: " << diff.tool_call_delta.arguments << "; ";
27+
}
28+
os << "}";
29+
return os;
30+
}
31+
// operator<< for vector<common_chat_msg_diff>:
32+
static std::ostream & operator<<(std::ostream & os, const std::vector<common_chat_msg_diff> & diffs) {
33+
os << "[\n";
34+
for (const auto & diff : diffs) {
35+
os << " " << diff << ",\n";
36+
}
37+
os << "]";
38+
return os;
39+
}
1940

2041
template <class T> static void assert_equals(const T & expected, const T & actual) {
2142
if (expected != actual) {
@@ -927,6 +948,90 @@ static void test_template_output_parsers() {
927948
}
928949
}
929950

951+
static void test_msg_diffs_compute() {
952+
{
953+
common_chat_msg msg1;
954+
955+
common_chat_msg msg2;
956+
msg2.content = "Hello, world!";
957+
958+
common_chat_msg_diff diff;
959+
diff.content_delta = "Hello, world!";
960+
961+
assert_equals(
962+
{diff},
963+
common_chat_msg_diff::compute_diffs(msg1, msg2));
964+
}
965+
{
966+
common_chat_msg msg1;
967+
msg1.content = "Hello,";
968+
969+
common_chat_msg msg2;
970+
msg2.content = "Hello, world!";
971+
972+
common_chat_msg_diff diff;
973+
diff.content_delta = " world!";
974+
975+
assert_equals(
976+
{diff},
977+
common_chat_msg_diff::compute_diffs(msg1, msg2));
978+
}
979+
{
980+
common_chat_msg msg0;
981+
982+
common_chat_msg msg1;
983+
msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } };
984+
985+
common_chat_msg msg2;
986+
msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } };
987+
988+
common_chat_msg_diff diff01;
989+
diff01.tool_call_index = 0;
990+
diff01.tool_call_delta.name = "special_function";
991+
diff01.tool_call_delta.id = "123";
992+
diff01.tool_call_delta.arguments = "{\"ar";
993+
994+
assert_equals(
995+
{diff01},
996+
common_chat_msg_diff::compute_diffs(msg0, msg1));
997+
998+
common_chat_msg_diff diff12;
999+
diff12.tool_call_index = 0;
1000+
diff12.tool_call_delta.name = "special_function";
1001+
diff12.tool_call_delta.id = "123";
1002+
diff12.tool_call_delta.arguments = "g1\": 1}";
1003+
1004+
assert_equals(
1005+
{diff12},
1006+
common_chat_msg_diff::compute_diffs(msg1, msg2));
1007+
}
1008+
{
1009+
common_chat_msg msg0;
1010+
1011+
common_chat_msg msg2;
1012+
msg2.tool_calls = {
1013+
{ "f1", "{\"arg1\": 1}", /* .id = */ "123" },
1014+
{ "f2", "{\"arg2\": 2}", /* .id = */ "222" },
1015+
};
1016+
1017+
common_chat_msg_diff diff1;
1018+
diff1.tool_call_index = 0;
1019+
diff1.tool_call_delta.name = "f1";
1020+
diff1.tool_call_delta.id = "123";
1021+
diff1.tool_call_delta.arguments = "{\"arg1\": 1}";
1022+
1023+
common_chat_msg_diff diff2;
1024+
diff2.tool_call_index = 1;
1025+
diff2.tool_call_delta.name = "f2";
1026+
diff2.tool_call_delta.id = "222";
1027+
diff2.tool_call_delta.arguments = "{\"arg2\": 2}";
1028+
1029+
assert_equals(
1030+
{diff1, diff2},
1031+
common_chat_msg_diff::compute_diffs(msg0, msg2));
1032+
}
1033+
}
1034+
9301035
int main(int argc, char ** argv) {
9311036
// try {
9321037
#ifndef _WIN32
@@ -960,6 +1065,7 @@ int main(int argc, char ** argv) {
9601065
} else
9611066
#endif
9621067
{
1068+
test_msg_diffs_compute();
9631069
test_msgs_oaicompat_json_conversion();
9641070
test_tools_oaicompat_json_conversion();
9651071
test_template_output_parsers();

0 commit comments

Comments
 (0)