Skip to content

Commit ed7c622

Browse files
author
ochafik
committed
Rename: common/chat.*, common_chat_{inputs -> params}
1 parent 6e676c8 commit ed7c622

File tree

9 files changed

+117
-123
lines changed

9 files changed

+117
-123
lines changed

Makefile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ TEST_TARGETS = \
5252
tests/test-arg-parser \
5353
tests/test-autorelease \
5454
tests/test-backend-ops \
55-
tests/test-chat-handler \
55+
tests/test-chat \
5656
tests/test-chat-template \
5757
tests/test-double-float \
5858
tests/test-grammar-integration \
@@ -984,7 +984,7 @@ OBJ_COMMON = \
984984
$(DIR_COMMON)/ngram-cache.o \
985985
$(DIR_COMMON)/sampling.o \
986986
$(DIR_COMMON)/speculative.o \
987-
$(DIR_COMMON)/chat-handler.o \
987+
$(DIR_COMMON)/chat.o \
988988
$(DIR_COMMON)/build-info.o \
989989
$(DIR_COMMON)/json-schema-to-grammar.o
990990

@@ -1363,8 +1363,8 @@ llama-server: \
13631363
examples/server/httplib.h \
13641364
examples/server/index.html.hpp \
13651365
examples/server/loading.html.hpp \
1366-
common/chat-handler.cpp \
1367-
common/chat-handler.hpp \
1366+
common/chat.cpp \
1367+
common/chat.hpp \
13681368
common/chat-template.hpp \
13691369
common/json.hpp \
13701370
common/minja.hpp \
@@ -1475,7 +1475,7 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \
14751475
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
14761476
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
14771477

1478-
tests/test-chat-handler: tests/test-chat-handler.cpp \
1478+
tests/test-chat: tests/test-chat.cpp \
14791479
$(OBJ_ALL)
14801480
$(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<)
14811481
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

common/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ add_library(${TARGET} STATIC
5656
arg.cpp
5757
arg.h
5858
base64.hpp
59-
chat-handler.cpp
60-
chat-handler.hpp
59+
chat.cpp
60+
chat.hpp
6161
chat-template.hpp
6262
common.cpp
6363
common.h

common/chat-handler.cpp renamed to common/chat.cpp

Lines changed: 84 additions & 84 deletions
Large diffs are not rendered by default.
Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
1-
/*
2-
Copyright 2024 Google LLC
1+
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
32

4-
Use of this source code is governed by an MIT-style
5-
license that can be found in the LICENSE file or at
6-
https://opensource.org/licenses/MIT.
7-
*/
8-
// SPDX-License-Identifier: MIT
93
#pragma once
104

115
#include "common.h"
@@ -16,7 +10,7 @@
1610

1711
using json = nlohmann::ordered_json;
1812

19-
struct common_chat_params {
13+
struct common_chat_inputs {
2014
json messages;
2115
json tools;
2216
json tool_choice;
@@ -29,7 +23,7 @@ struct common_chat_params {
2923

3024
typedef std::function<common_chat_msg(const std::string & input)> common_chat_parser;
3125

32-
struct common_chat_data {
26+
struct common_chat_params {
3327
json prompt;
3428
std::string grammar;
3529
std::vector<common_grammar_trigger> grammar_triggers;
@@ -39,4 +33,4 @@ struct common_chat_data {
3933
bool grammar_lazy = false;
4034
};
4135

42-
struct common_chat_data common_chat_init(const common_chat_template & tmpl, const struct common_chat_params & params);
36+
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);

common/common.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "json.hpp"
1313
#include "json-schema-to-grammar.h"
1414
#include "llama.h"
15-
#include "chat-handler.hpp"
15+
#include "chat.hpp"
1616
#include "chat-template.hpp"
1717

1818
#include <algorithm>
@@ -1776,12 +1776,12 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
17761776
if (use_jinja) {
17771777
try {
17781778
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
1779-
common_chat_params params;
1779+
common_chat_inputs params;
17801780
params.messages = json::array({{
17811781
{"role", "user"},
17821782
{"content", "test"},
17831783
}});
1784-
common_chat_init(chat_template, params);
1784+
common_chat_params_init(chat_template, params);
17851785
return true;
17861786
} catch (const std::exception & e) {
17871787
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
@@ -1803,10 +1803,10 @@ std::string common_chat_apply_template(
18031803
for (const auto & msg : msgs) {
18041804
messages.push_back({{"role", msg.role}, {"content", msg.content}});
18051805
}
1806-
common_chat_params params;
1806+
common_chat_inputs params;
18071807
params.messages = messages;
18081808
params.add_generation_prompt = add_ass;
1809-
auto data = common_chat_init(tmpl, params);
1809+
auto data = common_chat_params_init(tmpl, params);
18101810
return data.prompt;
18111811
}
18121812

examples/server/server.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,16 +1824,16 @@ struct server_context {
18241824

18251825
if (use_jinja) {
18261826
auto templates = common_chat_templates_from_model(model, "");
1827-
common_chat_params params;
1827+
common_chat_inputs params;
18281828
params.messages = json::array({{
18291829
{"role", "user"},
18301830
{"content", "test"},
18311831
}});
18321832
GGML_ASSERT(templates.template_default);
18331833
try {
1834-
common_chat_init(*templates.template_default, params);
1834+
common_chat_params_init(*templates.template_default, params);
18351835
if (templates.template_tool_use) {
1836-
common_chat_init(*templates.template_tool_use, params);
1836+
common_chat_params_init(*templates.template_tool_use, params);
18371837
}
18381838
return true;
18391839
} catch (const std::exception & e) {
@@ -3787,10 +3787,10 @@ int main(int argc, char ** argv) {
37873787
std::vector<server_task> tasks;
37883788

37893789
try {
3790-
common_chat_data chat_data;
3790+
common_chat_params chat_data;
37913791
bool add_special = false;
37923792
if (tmpl && ctx_server.params_base.use_jinja) {
3793-
chat_data = common_chat_init(*tmpl, {
3793+
chat_data = common_chat_params_init(*tmpl, {
37943794
/* .messages = */ json_value(data, "messages", json::array()),
37953795
/* .tools = */ json_value(data, "tools", json()),
37963796
/* .tool_choice = */ json_value(data, "tool_choice", std::string("auto")),

examples/server/utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#define JSON_ASSERT GGML_ASSERT
1818
#include "json.hpp"
1919
#include "minja.hpp"
20-
#include "chat-handler.hpp"
20+
#include "chat.hpp"
2121
#include "chat-template.hpp"
2222

2323
#include <random>

tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ if (NOT WIN32)
9393
llama_target_and_test(test-grammar-parser.cpp)
9494
llama_target_and_test(test-grammar-integration.cpp)
9595
llama_target_and_test(test-llama-grammar.cpp)
96-
llama_target_and_test(test-chat-handler.cpp)
96+
llama_target_and_test(test-chat.cpp)
9797
# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
9898
if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
9999
llama_target_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)

tests/test-chat-handler.cpp renamed to tests/test-chat.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "chat-handler.hpp"
1+
#include "chat.hpp"
22
#include "chat-template.hpp"
33
#include "llama-grammar.h"
44
#include "unicode.h"
@@ -169,15 +169,15 @@ struct delta_data {
169169
};
170170

171171
static delta_data init_delta(const common_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
172-
common_chat_params params;
172+
common_chat_inputs params;
173173
params.parallel_tool_calls = true;
174174
params.messages = json::array();
175175
params.messages.push_back(user_message);
176176
params.tools = tools;
177-
auto prefix_data = common_chat_init(tmpl, params);
177+
auto prefix_data = common_chat_params_init(tmpl, params);
178178
params.messages.push_back(delta_message);
179179
params.add_generation_prompt = false;
180-
auto full_data = common_chat_init(tmpl, params);
180+
auto full_data = common_chat_params_init(tmpl, params);
181181

182182
std::string prefix = prefix_data.prompt;
183183
std::string full = full_data.prompt;
@@ -220,7 +220,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
220220
};
221221

222222
for (const auto & tool_choice : json({"auto", "required"})) {
223-
common_chat_params params;
223+
common_chat_inputs params;
224224
params.tool_choice = tool_choice;
225225
params.parallel_tool_calls = true;
226226
params.messages = json {user_message, test_message};
@@ -301,15 +301,15 @@ static void test_template_output_parsers() {
301301
};
302302

303303

304-
common_chat_params no_tools_params;
304+
common_chat_inputs no_tools_params;
305305
no_tools_params.messages = {{{"role", "user"}, {"content", "Hey"}}};
306306

307-
common_chat_params tools_params = no_tools_params;
307+
common_chat_inputs tools_params = no_tools_params;
308308
tools_params.tools = json::array();
309309
tools_params.tools.push_back(special_function_tool);
310310

311-
auto describe = [](const common_chat_template & tmpl, const common_chat_params & params) {
312-
auto data = common_chat_init(tmpl, params);
311+
auto describe = [](const common_chat_template & tmpl, const common_chat_inputs & params) {
312+
auto data = common_chat_params_init(tmpl, params);
313313
return data.format;
314314
};
315315

@@ -322,7 +322,7 @@ static void test_template_output_parsers() {
322322
assert_equals(std::string("generic tool calls"), describe(common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"), "<s>", "</s>"), tools_params));
323323

324324
// Generic tool calls doesn't generate / parse content-only messages symmetrically.
325-
assert_msg_equals(msg_from_json(text_message), common_chat_init(tmpl, tools_params).parser(
325+
assert_msg_equals(msg_from_json(text_message), common_chat_params_init(tmpl, tools_params).parser(
326326
"{\n"
327327
" \"response\": \"Hello, world!\"\n"
328328
"}"));

0 commit comments

Comments
 (0)