|
16 | 16 |
|
17 | 17 | using json = nlohmann::ordered_json; |
18 | 18 |
|
| 19 | +static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) { |
| 20 | + // os << "reasoning_content_delta: " << diff.reasoning_content_delta << '\n'; |
| 21 | + os << "{ content_delta: " << diff.content_delta << "; "; |
| 22 | + if (diff.tool_call_index != std::string::npos) { |
| 23 | + os << "tool_call_index: " << diff.tool_call_index << "; "; |
| 24 | + os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; "; |
| 25 | + os << "tool_call_delta.id: " << diff.tool_call_delta.id << "; "; |
| 26 | + os << "tool_call_delta.arguments: " << diff.tool_call_delta.arguments << "; "; |
| 27 | + } |
| 28 | + os << "}"; |
| 29 | + return os; |
| 30 | +} |
| 31 | +// operator<< for vector<common_chat_msg_diff>: |
| 32 | +static std::ostream & operator<<(std::ostream & os, const std::vector<common_chat_msg_diff> & diffs) { |
| 33 | + os << "[\n"; |
| 34 | + for (const auto & diff : diffs) { |
| 35 | + os << " " << diff << ",\n"; |
| 36 | + } |
| 37 | + os << "]"; |
| 38 | + return os; |
| 39 | +} |
19 | 40 |
|
20 | 41 | template <class T> static void assert_equals(const T & expected, const T & actual) { |
21 | 42 | if (expected != actual) { |
@@ -927,6 +948,90 @@ static void test_template_output_parsers() { |
927 | 948 | } |
928 | 949 | } |
929 | 950 |
|
| 951 | +static void test_msg_diffs_compute() { |
| 952 | + { |
| 953 | + common_chat_msg msg1; |
| 954 | + |
| 955 | + common_chat_msg msg2; |
| 956 | + msg2.content = "Hello, world!"; |
| 957 | + |
| 958 | + common_chat_msg_diff diff; |
| 959 | + diff.content_delta = "Hello, world!"; |
| 960 | + |
| 961 | + assert_equals( |
| 962 | + {diff}, |
| 963 | + common_chat_msg_diff::compute_diffs(msg1, msg2)); |
| 964 | + } |
| 965 | + { |
| 966 | + common_chat_msg msg1; |
| 967 | + msg1.content = "Hello,"; |
| 968 | + |
| 969 | + common_chat_msg msg2; |
| 970 | + msg2.content = "Hello, world!"; |
| 971 | + |
| 972 | + common_chat_msg_diff diff; |
| 973 | + diff.content_delta = " world!"; |
| 974 | + |
| 975 | + assert_equals( |
| 976 | + {diff}, |
| 977 | + common_chat_msg_diff::compute_diffs(msg1, msg2)); |
| 978 | + } |
| 979 | + { |
| 980 | + common_chat_msg msg0; |
| 981 | + |
| 982 | + common_chat_msg msg1; |
| 983 | + msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } }; |
| 984 | + |
| 985 | + common_chat_msg msg2; |
| 986 | + msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } }; |
| 987 | + |
| 988 | + common_chat_msg_diff diff01; |
| 989 | + diff01.tool_call_index = 0; |
| 990 | + diff01.tool_call_delta.name = "special_function"; |
| 991 | + diff01.tool_call_delta.id = "123"; |
| 992 | + diff01.tool_call_delta.arguments = "{\"ar"; |
| 993 | + |
| 994 | + assert_equals( |
| 995 | + {diff01}, |
| 996 | + common_chat_msg_diff::compute_diffs(msg0, msg1)); |
| 997 | + |
| 998 | + common_chat_msg_diff diff12; |
| 999 | + diff12.tool_call_index = 0; |
| 1000 | + diff12.tool_call_delta.name = "special_function"; |
| 1001 | + diff12.tool_call_delta.id = "123"; |
| 1002 | + diff12.tool_call_delta.arguments = "g1\": 1}"; |
| 1003 | + |
| 1004 | + assert_equals( |
| 1005 | + {diff12}, |
| 1006 | + common_chat_msg_diff::compute_diffs(msg1, msg2)); |
| 1007 | + } |
| 1008 | + { |
| 1009 | + common_chat_msg msg0; |
| 1010 | + |
| 1011 | + common_chat_msg msg2; |
| 1012 | + msg2.tool_calls = { |
| 1013 | + { "f1", "{\"arg1\": 1}", /* .id = */ "123" }, |
| 1014 | + { "f2", "{\"arg2\": 2}", /* .id = */ "222" }, |
| 1015 | + }; |
| 1016 | + |
| 1017 | + common_chat_msg_diff diff1; |
| 1018 | + diff1.tool_call_index = 0; |
| 1019 | + diff1.tool_call_delta.name = "f1"; |
| 1020 | + diff1.tool_call_delta.id = "123"; |
| 1021 | + diff1.tool_call_delta.arguments = "{\"arg1\": 1}"; |
| 1022 | + |
| 1023 | + common_chat_msg_diff diff2; |
| 1024 | + diff2.tool_call_index = 1; |
| 1025 | + diff2.tool_call_delta.name = "f2"; |
| 1026 | + diff2.tool_call_delta.id = "222"; |
| 1027 | + diff2.tool_call_delta.arguments = "{\"arg2\": 2}"; |
| 1028 | + |
| 1029 | + assert_equals( |
| 1030 | + {diff1, diff2}, |
| 1031 | + common_chat_msg_diff::compute_diffs(msg0, msg2)); |
| 1032 | + } |
| 1033 | +} |
| 1034 | + |
930 | 1035 | int main(int argc, char ** argv) { |
931 | 1036 | // try { |
932 | 1037 | #ifndef _WIN32 |
@@ -960,6 +1065,7 @@ int main(int argc, char ** argv) { |
960 | 1065 | } else |
961 | 1066 | #endif |
962 | 1067 | { |
| 1068 | + test_msg_diffs_compute(); |
963 | 1069 | test_msgs_oaicompat_json_conversion(); |
964 | 1070 | test_tools_oaicompat_json_conversion(); |
965 | 1071 | test_template_output_parsers(); |
|
0 commit comments