Skip to content

Commit 1fb6372

Browse files
committed
feat: catch correct antiprompt according to the template, ref #65
1 parent 771aee0 commit 1fb6372

File tree

5 files changed

+42
-10
lines changed

5 files changed

+42
-10
lines changed

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class ChatSession {
5353
IoEndpoint& m_io;
5454

5555
std::string m_roleUser;
56+
std::string m_userPrefix;
5657
std::string m_roleAsistant;
5758
std::unique_ptr<llama::ChatFormat> m_chatFormat;
5859
std::vector<llama::ChatMsg> m_chatMessages;
@@ -77,6 +78,7 @@ class ChatSession {
7778
m_chatFormat = std::make_unique<llama::ChatFormat>(modelChatParams.chatTemplate);
7879
} else {
7980
modelChatParams.chatTemplate = chatTemplate;
81+
modelChatParams.roleAssistant = params.roleAssistant.value();
8082
m_chatFormat = std::make_unique<llama::ChatFormat>(std::move(modelChatParams));
8183
}
8284

@@ -85,6 +87,22 @@ class ChatSession {
8587

8688
m_roleUser = params.roleUser;
8789
m_roleAsistant = params.roleAssistant;
90+
91+
auto trim = [](const std::string& str) {
92+
auto begin = std::find_if_not(str.begin(), str.end(), [](unsigned char ch) {
93+
return std::isspace(ch);
94+
});
95+
96+
auto end = std::find_if_not(str.rbegin(), str.rend(), [](unsigned char ch) {
97+
return std::isspace(ch);
98+
}).base();
99+
100+
return (begin < end) ? std::string(begin, end) : "";
101+
};
102+
103+
// user prefix should be the substring that precedes the stop word
104+
m_userPrefix = m_chatFormat->formatMsg({.role = m_roleUser, .text = "stop"}, {}, false);
105+
m_userPrefix = trim(m_userPrefix.substr(0, m_userPrefix.find("stop")));
88106
}
89107

90108
~ChatSession() {
@@ -132,8 +150,7 @@ class ChatSession {
132150
}
133151

134152
ac::llama::AntipromptManager antiprompt;
135-
auto userPrefix = "\n" + m_roleUser + ": ";
136-
antiprompt.addAntiprompt(userPrefix);
153+
antiprompt.addAntiprompt(m_userPrefix);
137154

138155
std::string fullResponse;
139156
Schema::OpGetChatResponse::Return ret;

ac-local-plugin/example/ep-chat.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,18 @@ int main() try {
3838

3939
const std::string roleUser = "user";
4040
const std::string roleAssistant = "assistant";
41+
const std::string chatTemplate =
42+
"{% for message in messages %}"
43+
"{{ '<|' + message['role'] + '|>\\n' + message['content'] + '<|end|>' + '\\n' }}"
44+
"{% endfor %}"
45+
"{% if add_generation_prompt %}"
46+
"{{ '<|' + assistant_role + '|>\\n' }}"
47+
"{% endif %}";
4148

4249
sid = llama.call<schema::StateModelLoaded::OpStartInstance>({
4350
.instanceType = "chat",
4451
.setup = "A chat between a human user and a helpful AI assistant.",
52+
.chatTemplate = chatTemplate,
4553
.roleUser = roleUser,
4654
.roleAssistant = roleAssistant,
4755
});

ac-local-plugin/schema/ac/schema/LlamaCpp.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,11 @@ struct StateModelLoaded {
9696
v(ubatchSize, "ubatch_size", "Size of the context");
9797
v(ctrlVectorPaths, "ctrl_vectors", "Paths to the control vectors.");
9898
v(setup, "setup", "Initial setup prompt for the chat session");
99-
v(chatTemplate, "chat_template", "Chat template to use. If empty will use the model default");
99+
v(chatTemplate, "chat_template", "Valid Jinja chat template to use. If empty will use the model default");
100100
v(bosOverride, "bos_override", "BOS token to use with the custom template. If empty will use the model default");
101101
v(eosOverride, "eos_override", "EOS token to use with the custom template. If empty will use the model default");
102102
v(roleUser, "role_user", "Role name for the user");
103+
v(roleAssistant, "role_assistant", "Role name for the assistant");
103104
}
104105
};
105106

code/ac/llama/ChatFormat.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class JinjaImpl final : public ChatFormat::impl {
111111
JinjaImpl(ChatFormat::Params params)
112112
{
113113
m_templateStr = std::move(params.chatTemplate);
114+
m_assistantRole = std::move(params.roleAssistant);
114115

115116
try {
116117
m_minjaTemplate = std::make_unique<minja::chat_template>(m_templateStr, params.bosToken, params.eosToken);
@@ -121,9 +122,9 @@ class JinjaImpl final : public ChatFormat::impl {
121122

122123
~JinjaImpl() {}
123124

124-
virtual std::string formatChat(std::span<const ChatMsg> chat, bool /*addAssistantPrompt*/) const override {
125+
virtual std::string formatChat(std::span<const ChatMsg> chat, bool addAssistantPrompt) const override {
125126
auto [jChat, size] = ac2jsonChatMessages(chat);
126-
return size == 0 ? std::string{} : applyJinja(jChat);
127+
return size == 0 ? std::string{} : applyJinja(jChat, addAssistantPrompt);
127128
}
128129

129130
virtual std::string formatMsg(const ChatMsg& msg, std::span<const ChatMsg> history, bool addAssistantPrompt) const override {
@@ -132,10 +133,10 @@ class JinjaImpl final : public ChatFormat::impl {
132133
}
133134

134135
auto [jchat, size] = ac2jsonChatMessages(history);
135-
auto fmtHistory = applyJinja(jchat);
136+
auto fmtHistory = applyJinja(jchat, addAssistantPrompt);
136137

137138
jchat.push_back({{"role", msg.role}, {"content", msg.text}});
138-
auto fmtNew = applyJinja(jchat);
139+
auto fmtNew = applyJinja(jchat, addAssistantPrompt);
139140

140141
return fmtNew.substr(fmtHistory.size());
141142
}
@@ -156,19 +157,22 @@ class JinjaImpl final : public ChatFormat::impl {
156157
return {messages, size};
157158
}
158159

159-
std::string applyJinja(acnl::json jChat) const {
160+
std::string applyJinja(acnl::json jChat, bool addAssistantPrompt) const {
160161
auto startsWith = [](const std::string& str, const std::string& prefix) {
161162
return str.rfind(prefix, 0) == 0;
162163
};
163164

164165
minja::chat_template_inputs tmpl_inputs;
165166
tmpl_inputs.messages = jChat;
167+
tmpl_inputs.add_generation_prompt = addAssistantPrompt;
168+
tmpl_inputs.extra_context = {
169+
{"assistant_role", m_assistantRole}
170+
};
166171

167-
minja::chat_template_options tmpl_opts;
168172
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
169173
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
170174
// may be needed inside the template / between messages too.
171-
auto result = m_minjaTemplate->apply(tmpl_inputs, tmpl_opts);
175+
auto result = m_minjaTemplate->apply(tmpl_inputs);
172176
if (startsWith(result, m_minjaTemplate->bos_token())) {
173177
result = result.substr(m_minjaTemplate->bos_token().size());
174178
}
@@ -180,6 +184,7 @@ class JinjaImpl final : public ChatFormat::impl {
180184

181185
std::unique_ptr<minja::chat_template> m_minjaTemplate;
182186
std::string m_templateStr;
187+
std::string m_assistantRole;
183188
};
184189

185190

code/ac/llama/ChatFormat.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AC_LLAMA_EXPORT ChatFormat {
2222
std::string chatTemplate;
2323
std::string bosToken;
2424
std::string eosToken;
25+
std::string roleAssistant = "";
2526
};
2627

2728
explicit ChatFormat(std::string templateStr);

0 commit comments

Comments
 (0)