@@ -59,6 +59,8 @@ class ChatSession {
5959 std::vector<llama::ChatMsg> m_chatMessages;
6060 size_t m_submittedMessages = 0 ;
6161
62+ ac::llama::AntipromptManager m_antiprompt;
63+
6264public:
6365 using Schema = sc::StateChatInstance;
6466
@@ -103,6 +105,17 @@ class ChatSession {
103105 // user prefix should be the substring before "stop"
104106 m_userPrefix = m_chatFormat->formatMsg ({.role = m_roleUser, .text = " stop" }, {}, false );
105107 m_userPrefix = trim (m_userPrefix.substr (0 , m_userPrefix.find (" stop" )));
108+ m_antiprompt.addAntiprompt (m_userPrefix);
109+
110+ std::vector<llama::ChatMsg> msgs{
111+ {.role = m_roleAsistant, .text = " pre" },
112+ {.role = m_roleUser, .text = " post" },
113+ };
114+
115+ auto assistantEnd = m_chatFormat->formatChat (msgs, false );
116+ assistantEnd = assistantEnd.substr (assistantEnd.find (" pre" ) + 3 ); // 3 because of the length of "pre"
117+ assistantEnd = trim (assistantEnd.substr (0 , assistantEnd.find (" post" )));
118+ m_antiprompt.addAntiprompt (assistantEnd);
106119 }
107120
108121 ~ChatSession () {
@@ -149,8 +162,7 @@ class ChatSession {
149162 m_submittedMessages = m_chatMessages.size ();
150163 }
151164
152- ac::llama::AntipromptManager antiprompt;
153- antiprompt.addAntiprompt (m_userPrefix);
165+ m_antiprompt.reset ();
154166
155167 std::string fullResponse;
156168 Schema::OpGetChatResponse::Return ret;
@@ -167,7 +179,7 @@ class ChatSession {
167179 result += tokenStr;
168180 fullResponse += tokenStr;
169181
170- auto matchedAntiPrompt = antiprompt .feedGeneratedText (tokenStr);
182+ auto matchedAntiPrompt = m_antiprompt .feedGeneratedText (tokenStr);
171183 if (!matchedAntiPrompt.empty ()) {
172184 // and also hide it from the return value
173185 // note that we assume that m_userPrefix is always the final piece of text in the response
@@ -176,7 +188,7 @@ class ChatSession {
176188 break ;
177189 }
178190
179- if (isStreaming && !antiprompt .hasRunningAntiprompts ()) {
191+ if (isStreaming && !m_antiprompt .hasRunningAntiprompts ()) {
180192 co_await m_io.push (Frame_from (sc::StreamToken{}, result));
181193 result = {};
182194 }
0 commit comments