Skip to content

Commit bb2ca87

Browse files
committed
fix: return back the assistant role to the scheme, ref #65
1 parent 9e92fd1 commit bb2ca87

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ class ChatSession {
5252
llama::Instance& m_instance;
5353
IoEndpoint& m_io;
5454

55-
std::string m_userPrefix;
55+
std::string m_roleUser;
56+
std::string m_roleAsistant;
5657
std::unique_ptr<llama::ChatFormat> m_chatFormat;
5758
std::vector<llama::ChatMsg> m_chatMessages;
5859
size_t m_submittedMessages = 0;
@@ -82,9 +83,8 @@ class ChatSession {
8283
auto promptTokens = instance.model().vocab().tokenize(params.setup.value(), true, true);
8384
m_session.setInitialPrompt(promptTokens);
8485

85-
m_userPrefix = "\n";
86-
m_userPrefix += params.roleUser;
87-
m_userPrefix += ":";
86+
m_roleUser = params.roleUser;
87+
m_roleAsistant = params.roleAssistant;
8888
}
8989

9090
~ChatSession() {
@@ -132,8 +132,10 @@ class ChatSession {
132132
}
133133

134134
ac::llama::AntipromptManager antiprompt;
135-
antiprompt.addAntiprompt(m_userPrefix);
135+
auto userPrefix = "\n" + m_roleUser + ": ";
136+
antiprompt.addAntiprompt(userPrefix);
136137

138+
std::string fullResponse;
137139
Schema::OpGetChatResponse::Return ret;
138140
auto& result = ret.response.materialize();
139141

@@ -146,6 +148,7 @@ class ChatSession {
146148

147149
auto tokenStr = m_vocab.tokenToString(t);
148150
result += tokenStr;
151+
fullResponse += tokenStr;
149152

150153
auto matchedAntiPrompt = antiprompt.feedGeneratedText(tokenStr);
151154
if (!matchedAntiPrompt.empty()) {
@@ -167,6 +170,7 @@ class ChatSession {
167170
// with a leading space, so instead of burdening them with "unorthodox" tokens, we'll clear it here
168171
if (!result.empty() && result[0] == ' ') {
169172
result.erase(0, 1);
173+
fullResponse.erase(0, 1);
170174
}
171175

172176
if (isStreaming) {
@@ -180,6 +184,8 @@ class ChatSession {
180184
.response = std::move(result)
181185
}));
182186
}
187+
188+
m_chatMessages.push_back({.role = m_roleAsistant, .text = std::move(fullResponse)});
183189
}
184190
};
185191

ac-local-plugin/example/ep-chat.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ int main() try {
4343
.instanceType = "chat",
4444
.setup = "A chat between a human user and a helpful AI assistant.",
4545
.roleUser = roleUser,
46+
.roleAssistant = roleAssistant,
4647
});
4748
std::cout << "Instance started: " << sid << '\n';
4849

ac-local-plugin/schema/ac/schema/LlamaCpp.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ struct StateModelLoaded {
8686
Field<std::string> bosOverride = Default();
8787
Field<std::string> eosOverride = Default();
8888
Field<std::string> roleUser = Default("User");
89+
Field<std::string> roleAssistant = Default("Assistant");
8990

9091
template <typename Visitor>
9192
void visitFields(Visitor& v) {

0 commit comments

Comments
 (0)