Skip to content

Commit 2c8d4cf

Browse files
committed
fix: capture ending of the assistant prompt too since we need to erase it from the response, ref #65
1 parent 1fb6372 commit 2c8d4cf

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class ChatSession {
5959
std::vector<llama::ChatMsg> m_chatMessages;
6060
size_t m_submittedMessages = 0;
6161

62+
ac::llama::AntipromptManager m_antiprompt;
63+
6264
public:
6365
using Schema = sc::StateChatInstance;
6466

@@ -103,6 +105,17 @@ class ChatSession {
103105
// user prefix should be a substr before stop
104106
m_userPrefix = m_chatFormat->formatMsg({.role = m_roleUser, .text = "stop"}, {}, false);
105107
m_userPrefix = trim(m_userPrefix.substr(0, m_userPrefix.find("stop")));
108+
m_antiprompt.addAntiprompt(m_userPrefix);
109+
110+
std::vector<llama::ChatMsg> msgs{
111+
{.role = m_roleAsistant, .text = "pre"},
112+
{.role = m_roleUser, .text = "post"},
113+
};
114+
115+
auto assistantEnd = m_chatFormat->formatChat(msgs, false);
116+
assistantEnd = assistantEnd.substr(assistantEnd.find("pre") + 3); // 3 because of the length of "pre"
117+
assistantEnd = trim(assistantEnd.substr(0, assistantEnd.find("post")));
118+
m_antiprompt.addAntiprompt(assistantEnd);
106119
}
107120

108121
~ChatSession() {
@@ -149,8 +162,7 @@ class ChatSession {
149162
m_submittedMessages = m_chatMessages.size();
150163
}
151164

152-
ac::llama::AntipromptManager antiprompt;
153-
antiprompt.addAntiprompt(m_userPrefix);
165+
m_antiprompt.reset();
154166

155167
std::string fullResponse;
156168
Schema::OpGetChatResponse::Return ret;
@@ -167,7 +179,7 @@ class ChatSession {
167179
result += tokenStr;
168180
fullResponse += tokenStr;
169181

170-
auto matchedAntiPrompt = antiprompt.feedGeneratedText(tokenStr);
182+
auto matchedAntiPrompt = m_antiprompt.feedGeneratedText(tokenStr);
171183
if (!matchedAntiPrompt.empty()) {
172184
// and also hide it from the return value
173185
// note that we assume that m_userPrefix is always the final piece of text in the response
@@ -176,7 +188,7 @@ class ChatSession {
176188
break;
177189
}
178190

179-
if (isStreaming && !antiprompt.hasRunningAntiprompts()) {
191+
if (isStreaming && !m_antiprompt.hasRunningAntiprompts()) {
180192
co_await m_io.push(Frame_from(sc::StreamToken{}, result));
181193
result = {};
182194
}

ac-local-plugin/example/ep-chat.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ int main() try {
4646
"{{ '<|' + assistant_role + '|>\\n' }}"
4747
"{% endif %}";
4848

49+
constexpr bool useChatTemplate = false;
4950
sid = llama.call<schema::StateModelLoaded::OpStartInstance>({
5051
.instanceType = "chat",
5152
.setup = "A chat between a human user and a helpful AI assistant.",
52-
.chatTemplate = chatTemplate,
53+
.chatTemplate = useChatTemplate ? chatTemplate : "",
5354
.roleUser = roleUser,
5455
.roleAssistant = roleAssistant,
5556
});

code/ac/llama/AntipromptManager.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,22 @@ void AntipromptManager::addAntiprompt(std::string_view antiprompt) {
1010
}
1111

1212
std::string AntipromptManager::feedGeneratedText(std::string_view text) {
13+
std::vector<std::pair<std::string, size_t>> matchedAntiprompts;
1314
for (auto& ap : m_antiprompts) {
1415
int found = ap.feedText(text);
1516
if (found > 0) {
16-
reset();
17-
return found == 0 ?
18-
ap.getString():
19-
ap.getString() + std::string(text.substr(found, text.length()));
17+
auto res = found == 0 ?
18+
ap.getString():
19+
ap.getString() + std::string(text.substr(found, text.length()));
20+
matchedAntiprompts.push_back({res, found});
2021
}
2122
}
23+
if (!matchedAntiprompts.empty()) {
24+
reset();
25+
std::sort(matchedAntiprompts.begin(), matchedAntiprompts.end());
26+
auto& [res, found] = matchedAntiprompts.front();
27+
return res;
28+
}
2229

2330
return {};
2431
}

0 commit comments

Comments
 (0)