 #include <ac/llama/ControlVector.hpp>
 #include <ac/llama/LogitComparer.hpp>
 #include <ac/llama/ResourceCache.hpp>
+#include <ac/llama/ChatFormat.hpp>

 #include <ac/local/Service.hpp>
 #include <ac/local/ServiceFactory.hpp>
 #include "aclp-llama-version.h"
 #include "aclp-llama-interface.hpp"

-// TODO: remove this include
-#include <iostream>
-
 namespace ac::local {

 namespace {
@@ -53,13 +51,12 @@ class ChatSession {
     const llama::Vocab& m_vocab;
     llama::Instance& m_instance;
     IoEndpoint& m_io;
-    std::string m_userPrefix;
-    std::string m_assistantPrefix;

-    std::vector<llama::Token> m_promptTokens;
+    std::string m_userPrefix;
+    std::unique_ptr<llama::ChatFormat> m_chatFormat;
+    std::vector<llama::ChatMsg> m_chatMessages;
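+    // number of entries in m_chatMessages that have already been submitted to the session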
+    size_t m_submittedMessages = 0;

-    bool m_addUserPrefix = true;
-    bool m_addAssistantPrefix = true;
 public:
     using Schema = sc::StateChatInstance;

@@ -69,50 +66,57 @@ class ChatSession {
         , m_instance(instance)
         , m_io(io)
     {
-        m_promptTokens = instance.model().vocab().tokenize(params.setup.value(), true, true);
-        m_session.setInitialPrompt(m_promptTokens);
+        auto& chatTemplate = params.chatTemplate.value();
+        auto modelChatParams = llama::ChatFormat::getChatParams(instance.model());
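+        // use the explicitly provided chat template if there is one; otherwise fall back to the model's built-in template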
+        if (chatTemplate.empty()) {
+            if (modelChatParams.chatTemplate.empty()) {
+                throw_ex{} << "The model does not have a default chat template; please provide one.";
+            }
+
+            m_chatFormat = std::make_unique<llama::ChatFormat>(modelChatParams.chatTemplate);
+        } else {
+            modelChatParams.chatTemplate = chatTemplate;
+            m_chatFormat = std::make_unique<llama::ChatFormat>(std::move(modelChatParams));
+        }
+
+        auto promptTokens = instance.model().vocab().tokenize(params.setup.value(), true, true);
+        m_session.setInitialPrompt(promptTokens);

         m_userPrefix = "\n";
         m_userPrefix += params.roleUser;
         m_userPrefix += ":";
-        m_assistantPrefix = "\n";
-        m_assistantPrefix += params.roleAssistant;
-        m_assistantPrefix += ":";
     }

     ~ChatSession() {
         m_instance.stopSession();
     }

-    xec::coro<void> sendMessages(Schema::OpSendMessages::Params& params) {
+    xec::coro<void> addMessages(Schema::OpAddChatMessages::Params& params) {
         auto& messages = params.messages.value();
-        for (size_t i = 0; i < messages.size(); ++i) {
-            std::cout << messages[i].role.value() << ": " << messages[i].content.value() << "\n";
+        std::vector<llama::Token> tokens;
+
+        for (const auto& message : messages) {
+            m_chatMessages.push_back(llama::ChatMsg{
+                .role = message.role.value(),
+                .text = message.content.value()
+            });
         }

-        co_await m_io.push(Frame_from(schema::SimpleOpReturn<Schema::OpSendMessages>{}, {}));
+        co_await m_io.push(Frame_from(schema::SimpleOpReturn<Schema::OpAddChatMessages>{}, {}));
     }

-    xec::coro<void> pushPrompt(Schema::OpAddChatPrompt::Params& params) {
-        auto& prompt = params.prompt.value();
-
-        // prefix with space as the generated content doesn't include it
-        prompt = ' ' + prompt;
-
-        if (m_addUserPrefix) {
-            // we haven't had an interaction yet, so we need to add the user prefix
-            // subsequent interaction will have it generated
-            prompt = m_userPrefix + prompt;
+    void submitPendingMessages() {
+        auto messagesToSubmit = m_chatMessages.size() - m_submittedMessages;
+        std::string formatted;
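+        // a single pending message is formatted against the previous messages as history;
+        // multiple pending messages are formatted together as a chat fragment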
+        if (messagesToSubmit == 1) {
+            formatted = m_chatFormat->formatMsg(
+                m_chatMessages.back(), {m_chatMessages.begin(), m_chatMessages.end() - 1}, true);
+        } else {
+            formatted = m_chatFormat->formatChat(
+                {m_chatMessages.begin() + m_submittedMessages, m_chatMessages.end()}, true);
         }

-        // prepare for the next generation
-        prompt += m_assistantPrefix;
-
-        m_promptTokens = m_vocab.tokenize(prompt, false, false);
-        m_session.pushPrompt(m_promptTokens);
-        m_addAssistantPrefix = false;
-
-        co_await m_io.push(Frame_from(schema::SimpleOpReturn<Schema::OpAddChatPrompt>{}, {}));
+        m_session.pushPrompt(m_vocab.tokenize(formatted, true, true));
     }

     xec::coro<void> getResponse(Schema::ChatResponseParams params, bool isStreaming) {
@@ -122,18 +126,14 @@ class ChatSession {
             maxTokens = 1000;
         }

-        if (m_addAssistantPrefix) {
-            // generated responses are requested first, but we haven't yet fed the assistant prefix to the model
-            auto prompt = m_assistantPrefix;
-            assert(m_promptTokens.empty()); // nothing should be pending here
-            m_promptTokens = m_vocab.tokenize(prompt, false, false);
-            m_session.pushPrompt(m_promptTokens);
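+        // submit any chat messages that haven't been fed to the model yet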
+        if (m_submittedMessages != m_chatMessages.size()) {
+            submitPendingMessages();
+            m_submittedMessages = m_chatMessages.size();
         }

         ac::llama::AntipromptManager antiprompt;
         antiprompt.addAntiprompt(m_userPrefix);

-        m_addUserPrefix = true;
         Schema::OpGetChatResponse::Return ret;
         auto& result = ret.response.materialize();

@@ -149,14 +149,10 @@ class ChatSession {

             auto matchedAntiPrompt = antiprompt.feedGeneratedText(tokenStr);
             if (!matchedAntiPrompt.empty()) {
-                // user prefix was added by generation, so don't add it again
-                m_addUserPrefix = false;
-
                 // and also hide it from the return value
                 // note that we assume that m_userPrefix is always the final piece of text in the response
                 // TODO: update to better match the cutoff when issue #131 is done
                 result.erase(result.size() - matchedAntiPrompt.size());
-                m_addUserPrefix = false;
                 break;
             }

@@ -447,14 +443,12 @@ struct LocalLlama {
         Frame err;

         try {
-            if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpAddChatPrompt>{}, *f)) {
-                co_await chatSession.pushPrompt(*iparams);
-            } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpGetChatResponse>{}, *f)) {
+            if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpGetChatResponse>{}, *f)) {
                 co_await chatSession.getResponse(*iparams, false);
             } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpStreamChatResponse>{}, *f)) {
                 co_await chatSession.getResponse(*iparams, true);
-            } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpSendMessages>{}, *f)) {
-                co_await chatSession.sendMessages(*iparams);
+            } else if (auto iparams = Frame_optTo(schema::OpParams<Schema::OpAddChatMessages>{}, *f)) {
+                co_await chatSession.addMessages(*iparams);
             } else {
                 err = unknownOpError(*f);
             }
@@ -478,7 +472,6 @@ struct LocalLlama {
         lparams.vocabOnly = lmParams.vocabOnly.valueOr(false);
         lparams.prefixInputsWithBos = lmParams.prefixInputsWithBos.valueOr(false);

-
         auto model = m_resourceCache.getModel({.gguf = gguf, .params = lparams});

         std::vector<llama::ResourceCache::LoraLock> loras;