
Commit 9e92fd1

feat: add custom chat template, ref #65
1 parent 101e541 commit 9e92fd1

File tree

3 files changed: +77 -97 lines changed

ac-local-plugin/code/LocalLlama.cpp
ac-local-plugin/example/ep-chat.cpp
ac-local-plugin/schema/ac/schema/LlamaCpp.hpp

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 45 additions & 52 deletions
@@ -10,6 +10,7 @@
 #include <ac/llama/ControlVector.hpp>
 #include <ac/llama/LogitComparer.hpp>
 #include <ac/llama/ResourceCache.hpp>
+#include <ac/llama/ChatFormat.hpp>

 #include <ac/local/Service.hpp>
 #include <ac/local/ServiceFactory.hpp>
@@ -38,9 +39,6 @@
 #include "aclp-llama-version.h"
 #include "aclp-llama-interface.hpp"

-// TODO: remove this include
-#include <iostream>
-
 namespace ac::local {

 namespace {
@@ -53,13 +51,12 @@ class ChatSession {
     const llama::Vocab& m_vocab;
     llama::Instance& m_instance;
     IoEndpoint& m_io;
-    std::string m_userPrefix;
-    std::string m_assistantPrefix;

-    std::vector<llama::Token> m_promptTokens;
+    std::string m_userPrefix;
+    std::unique_ptr<llama::ChatFormat> m_chatFormat;
+    std::vector<llama::ChatMsg> m_chatMessages;
+    size_t m_submittedMessages = 0;

-    bool m_addUserPrefix = true;
-    bool m_addAssistantPrefix = true;
 public:
     using Schema = sc::StateChatInstance;

@@ -69,50 +66,57 @@ class ChatSession {
         , m_instance(instance)
         , m_io(io)
     {
-        m_promptTokens = instance.model().vocab().tokenize(params.setup.value(), true, true);
-        m_session.setInitialPrompt(m_promptTokens);
+        auto& chatTemplate = params.chatTemplate.value();
+        auto modelChatParams = llama::ChatFormat::getChatParams(instance.model());
+        if (chatTemplate.empty()) {
+            if (modelChatParams.chatTemplate.empty()) {
+                throw_ex{} << "The model does not have a default chat template, please provide one.";
+            }
+
+            m_chatFormat = std::make_unique<llama::ChatFormat>(modelChatParams.chatTemplate);
+        } else {
+            modelChatParams.chatTemplate = chatTemplate;
+            m_chatFormat = std::make_unique<llama::ChatFormat>(std::move(modelChatParams));
+        }
+
+        auto promptTokens = instance.model().vocab().tokenize(params.setup.value(), true, true);
+        m_session.setInitialPrompt(promptTokens);

         m_userPrefix = "\n";
         m_userPrefix += params.roleUser;
         m_userPrefix += ":";
-        m_assistantPrefix = "\n";
-        m_assistantPrefix += params.roleAssistant;
-        m_assistantPrefix += ":";
     }

     ~ChatSession() {
         m_instance.stopSession();
     }

-    xec::coro<void> sendMessages(Schema::OpSendMessages::Params& params) {
+    xec::coro<void> addMessages(Schema::OpAddChatMessages::Params& params) {
         auto& messages = params.messages.value();
-        for (size_t i = 0; i < messages.size(); ++i) {
-            std::cout << messages[i].role.value() << ": " << messages[i].content.value() << "\n";
+        std::vector<llama::Token> tokens;
+
+        for (const auto& message : messages) {
+            m_chatMessages.push_back(llama::ChatMsg{
+                .role = message.role.value(),
+                .text = message.content.value()
+            });
         }

-        co_await m_io.push(Frame_from(schema::SimpleOpReturn<Schema::OpSendMessages>{}, {}));
+        co_await m_io.push(Frame_from(schema::SimpleOpReturn<Schema::OpAddChatMessages>{}, {}));
     }

-    xec::coro<void> pushPrompt(Schema::OpAddChatPrompt::Params& params) {
-        auto& prompt = params.prompt.value();
-
-        // prefix with space as the generated content doesn't include it
-        prompt = ' ' + prompt;
-
-        if (m_addUserPrefix) {
-            // we haven't had an interaction yet, so we need to add the user prefix
-            // subsequent interaction will have it generated
-            prompt = m_userPrefix + prompt;
+    void submitPendingImages() {
+        auto messagesToSubmit = m_chatMessages.size() - m_submittedMessages;
+        std::string formatted;
+        if (messagesToSubmit == 1) {
+            formatted = m_chatFormat->formatMsg(
+                m_chatMessages.back(), {m_chatMessages.begin(), m_chatMessages.end() - 1}, true);
+        } else {
+            formatted = m_chatFormat->formatChat(
+                {m_chatMessages.begin() + m_submittedMessages, m_chatMessages.end()}, true);
         }

-        // prepare for the next generation
-        prompt += m_assistantPrefix;
-
-        m_promptTokens = m_vocab.tokenize(prompt, false, false);
-        m_session.pushPrompt(m_promptTokens);
-        m_addAssistantPrefix = false;
-
-        co_await m_io.push(Frame_from(schema::SimpleOpReturn<Schema::OpAddChatPrompt>{}, {}));
+        m_session.pushPrompt(m_vocab.tokenize(formatted, true, true));
     }

     xec::coro<void> getResponse(Schema::ChatResponseParams params, bool isStreaming) {
@@ -122,18 +126,14 @@ class ChatSession {
             maxTokens = 1000;
         }

-        if (m_addAssistantPrefix) {
-            // generated responses are requested first, but we haven't yet fed the assistant prefix to the model
-            auto prompt = m_assistantPrefix;
-            assert(m_promptTokens.empty()); // nothing should be pending here
-            m_promptTokens = m_vocab.tokenize(prompt, false, false);
-            m_session.pushPrompt(m_promptTokens);
+        if (m_submittedMessages != m_chatMessages.size()) {
+            submitPendingImages();
+            m_submittedMessages = m_chatMessages.size();
         }

         ac::llama::AntipromptManager antiprompt;
         antiprompt.addAntiprompt(m_userPrefix);

-        m_addUserPrefix = true;
         Schema::OpGetChatResponse::Return ret;
         auto& result = ret.response.materialize();

@@ -149,14 +149,10 @@ class ChatSession {

             auto matchedAntiPrompt = antiprompt.feedGeneratedText(tokenStr);
             if (!matchedAntiPrompt.empty()) {
-                // user prefix was added by generation, so don't add it again
-                m_addUserPrefix = false;
-
                 // and also hide it from the return value
                 // note that we assume that m_userPrefix is always the final piece of text in the response
                 // TODO: update to better match the cutoff when issue #131 is done
                 result.erase(result.size() - matchedAntiPrompt.size());
-                m_addUserPrefix = false;
                 break;
             }

@@ -447,14 +443,12 @@ struct LocalLlama {
         Frame err;

         try {
-            if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpAddChatPrompt>{}, *f)) {
-                co_await chatSession.pushPrompt(*iparams);
-            } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpGetChatResponse>{}, *f)) {
+            if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpGetChatResponse>{}, *f)) {
                 co_await chatSession.getResponse(*iparams, false);
             } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpStreamChatResponse>{}, *f)) {
                 co_await chatSession.getResponse(*iparams, true);
-            } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpSendMessages>{}, *f)) {
-                co_await chatSession.sendMessages(*iparams);
+            } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpAddChatMessages>{}, *f)) {
+                co_await chatSession.addMessages(*iparams);
             } else {
                 err = unknownOpError(*f);
             }
@@ -478,7 +472,6 @@ struct LocalLlama {
         lparams.vocabOnly = lmParams.vocabOnly.valueOr(false);
         lparams.prefixInputsWithBos = lmParams.prefixInputsWithBos.valueOr(false);

-
        auto model = m_resourceCache.getModel({.gguf = gguf, .params = lparams});

        std::vector<llama::ResourceCache::LoraLock> loras;
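
The core of the new flow is how pending messages become prompt text: the session counts how many messages have already been fed to the model and formats only the remainder. The submitPendingImages helper (which, per the diff, submits pending chat messages) uses formatMsg when exactly one message is pending, so the template can render it relative to the full history, and formatChat for a batch. Below is a minimal self-contained sketch of that bookkeeping with toy stand-ins for llama::ChatMsg and llama::ChatFormat; the trailing bool is assumed here to request the assistant generation prompt, and none of this is the real ac::llama API.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for llama::ChatMsg and llama::ChatFormat, for illustration only.
struct Msg { std::string role, text; };

struct ToyFormat {
    // Format one message, given the history that precedes it.
    std::string formatMsg(const Msg& m, const std::vector<Msg>& history, bool addAssistantPrompt) const {
        (void)history; // a real template could consult the history (e.g. to emit BOS only once)
        std::string out = "<" + m.role + ">" + m.text;
        if (addAssistantPrompt) out += "<assistant>";
        return out;
    }
    // Format a run of messages in one go.
    std::string formatChat(const std::vector<Msg>& msgs, bool addAssistantPrompt) const {
        std::string out;
        for (const auto& m : msgs) out += "<" + m.role + ">" + m.text;
        if (addAssistantPrompt) out += "<assistant>";
        return out;
    }
};

int main() {
    std::vector<Msg> chat = {
        {"user", "hi"}, {"assistant", "hello"}, {"user", "one more question"}
    };
    size_t submitted = 2; // the first two messages were already pushed to the model
    ToyFormat fmt;

    std::string formatted;
    if (chat.size() - submitted == 1) {
        // exactly one pending message: format it against the full preceding history
        formatted = fmt.formatMsg(chat.back(), {chat.begin(), chat.end() - 1}, true);
    } else {
        // several pending messages: format the whole pending span at once
        formatted = fmt.formatChat({chat.begin() + submitted, chat.end()}, true);
    }
    submitted = chat.size();

    // only this delta would be tokenized and pushed to the session
    std::cout << formatted << '\n';
}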

ac-local-plugin/example/ep-chat.cpp

Lines changed: 23 additions & 23 deletions
@@ -31,8 +31,7 @@ int main() try {
     std::cout << "Initial state: " << sid << '\n';

     for (auto x : llama.stream<schema::StateLlama::OpLoadModel>({
-        // .ggufPath = AC_TEST_DATA_LLAMA_DIR "/gpt2-117m-q6_k.gguf"
-        .ggufPath = AC_TEST_DATA_LLAMA_DIR "/../../../tmp/Meta-Llama-3.1-8B-Instruct-Q6_K.gguf"
+        .ggufPath = AC_TEST_DATA_LLAMA_DIR "/gpt2-117m-q6_k.gguf"
     })) {
         std::cout << "Model loaded: " << x.tag.value() << " " << x.progress.value() << '\n';
     }
@@ -43,22 +42,27 @@
     sid = llama.call<schema::StateModelLoaded::OpStartInstance>({
         .instanceType = "chat",
         .setup = "A chat between a human user and a helpful AI assistant.",
-        // .roleUser = roleUser,
-        // .roleAssistant = roleAssistant
+        .roleUser = roleUser,
     });
     std::cout << "Instance started: " << sid << '\n';

-    std::vector<schema::Message> initMessages = {
-        {roleUser, "I need assistance for API design"},
-        {roleAssistant, "What aspect of API design are you looking for help with? Do you have a specific problem or question in mind?"},
-        {roleUser, "It's a C++ implementation of a class"},
-    };
-
-    llama.call<schema::StateChatInstance::OpSendMessages>({
-        .messages = initMessages
-    });
+    constexpr bool addPreviousMessages = true;
+    if (addPreviousMessages) {
+        std::vector<schema::Message> msgs = {
+            {roleUser, "Hey, I need help planning a surprise weekend getaway."},
+            {roleAssistant, "Sure! Are you thinking of something outdoorsy, a relaxing spa weekend, or maybe a city adventure?"},
+            {roleUser, "A quiet nature retreat would be perfect."},
+            {roleAssistant, "Great choice. I can suggest a few scenic cabin locations and even help you build a checklist for the trip."}
+        };
+
+        llama.call<schema::StateChatInstance::OpAddChatMessages>({
+            .messages = msgs
+        });

-    std::vector<schema::Message> messages;
+        for (auto& m : msgs) {
+            std::cout << m.role.value() << ": " << m.content.value() << '\n';
+        }
+    }

     while (true) {
         std::cout << roleUser <<": ";
@@ -67,27 +71,23 @@
             std::getline(std::cin, user);
         }
         if (user == "/quit") break;
-        user = ' ' + user;
-        messages.push_back({roleUser, user});

-        llama.call<schema::StateChatInstance::OpAddChatPrompt>({
-            .prompt = user
+        llama.call<schema::StateChatInstance::OpAddChatMessages>({
+            .messages = std::vector<schema::Message>{
+                { roleUser, user}
+            }
         });

-        std::string text;
         std::cout << roleAssistant << ": ";
-        constexpr bool streamChat = false;
+        constexpr bool streamChat = true;
         if (streamChat) {
             for(auto t: llama.stream<schema::StateChatInstance::OpStreamChatResponse>({})) {
-                text += t;
                 std::cout << t << std::flush;
             }
         } else {
             auto res = llama.call<schema::StateChatInstance::OpGetChatResponse>({});
-            text += res.response.value();
             std::cout << res.response.value() << std::flush;
         }
-        messages.push_back({roleUser, text});
         std::cout << "\n";
     }

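
With the add-chat-prompt op gone, the example client no longer keeps a local transcript or prepends role prefixes by hand: each turn is sent as a plain role/content message through the add-messages op, and formatting is left to the server-side chat template. The example also goes back to the small gpt2 test model and switches the response path to streaming by default.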

ac-local-plugin/schema/ac/schema/LlamaCpp.hpp

Lines changed: 9 additions & 22 deletions
@@ -83,19 +83,22 @@ struct StateModelLoaded {

         Field<std::string> setup = Default();
         Field<std::string> chatTemplate = Default();
+        Field<std::string> bosOverride = Default();
+        Field<std::string> eosOverride = Default();
         Field<std::string> roleUser = Default("User");
-        Field<std::string> roleAssistant = Default("Assistant");

         template <typename Visitor>
         void visitFields(Visitor& v) {
             v(instanceType, "instance_type", "Type of the instance to start");
             v(ctxSize, "ctx_size", "Size of the context");
             v(batchSize, "batch_size", "Size of the single batch");
             v(ubatchSize, "ubatch_size", "Size of the context");
-            v(ctrlVectorPaths, "ctrl-vectors", "Paths to the control vectors.");
-            v(setup, "setup", "Initial setup for the chat session");
+            v(ctrlVectorPaths, "ctrl_vectors", "Paths to the control vectors.");
+            v(setup, "setup", "Initial setup prompt for the chat session");
+            v(chatTemplate, "chat_template", "Chat template to use. If empty will use the model default");
+            v(bosOverride, "bos_override", "BOS token to use with the custom template. If empty will use the model default");
+            v(eosOverride, "eos_override", "EOS token to use with the custom template. If empty will use the model default");
             v(roleUser, "role_user", "Role name for the user");
-            v(roleAssistant, "role_assistant", "Role name for the assistant");
         }
     };

@@ -220,8 +223,8 @@ struct StateChatInstance {
     static constexpr auto id = "chat-instance";
     static constexpr auto desc = "Chat state";

-    struct OpSendMessages {
-        static inline constexpr std::string_view id = "send-messages";
+    struct OpAddChatMessages {
+        static inline constexpr std::string_view id = "add-messages";
         static inline constexpr std::string_view desc = "Send messages to the chat session";

         struct Params {
@@ -236,22 +239,6 @@ struct StateChatInstance {
         using Return = nullptr_t;
     };

-    struct OpAddChatPrompt {
-        static inline constexpr std::string_view id = "add-chat-prompt";
-        static inline constexpr std::string_view desc = "Add a prompt to the chat session as a user";
-
-        struct Params {
-            Field<std::string> prompt = Default();
-
-            template <typename Visitor>
-            void visitFields(Visitor& v) {
-                v(prompt, "prompt", "Prompt to add to the chat session");
-            }
-        };
-
-        using Return = nullptr_t;
-    };
-
     struct ChatResponseParams {
         Field<uint32_t> maxTokens = Default(0);

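
For callers, the net effect of the schema change is three new optional fields on OpStartInstance (chat_template, bos_override, eos_override) and the removal of role_assistant. A hedged sketch of supplying them in the style of the example client follows; the template string and the BOS/EOS values are illustrative placeholders, and nothing in this commit shows how the overrides are consumed downstream.

// Sketch only: customTemplate and the BOS/EOS strings are placeholders, not values from this repo.
std::string customTemplate = /* a chat template understood by llama::ChatFormat */ "...";

sid = llama.call<schema::StateModelLoaded::OpStartInstance>({
    .instanceType = "chat",
    .setup = "A chat between a human user and a helpful AI assistant.",
    .chatTemplate = customTemplate,     // empty -> fall back to the model's built-in template
    .bosOverride = "<|begin_of_text|>", // optional BOS for the custom template
    .eosOverride = "<|eot_id|>",        // optional EOS for the custom template
    .roleUser = "User",
});

If chat_template is left empty and the model ships no template of its own, the session constructor in LocalLlama.cpp throws, so a template must come from one side or the other.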
