@@ -111,6 +111,7 @@ class JinjaImpl final : public ChatFormat::impl {
111111 JinjaImpl (ChatFormat::Params params)
112112 {
113113 m_templateStr = std::move (params.chatTemplate );
114+ m_assistantRole = std::move (params.roleAssistant );
114115
115116 try {
116117 m_minjaTemplate = std::make_unique<minja::chat_template>(m_templateStr, params.bosToken , params.eosToken );
@@ -121,9 +122,9 @@ class JinjaImpl final : public ChatFormat::impl {
121122
122123 ~JinjaImpl () {}
123124
124- virtual std::string formatChat (std::span<const ChatMsg> chat, bool /* addAssistantPrompt*/ ) const override {
125+ virtual std::string formatChat (std::span<const ChatMsg> chat, bool addAssistantPrompt) const override {
125126 auto [jChat, size] = ac2jsonChatMessages (chat);
126- return size == 0 ? std::string{} : applyJinja (jChat);
127+ return size == 0 ? std::string{} : applyJinja (jChat, addAssistantPrompt );
127128 }
128129
129130 virtual std::string formatMsg (const ChatMsg& msg, std::span<const ChatMsg> history, bool addAssistantPrompt) const override {
@@ -132,10 +133,10 @@ class JinjaImpl final : public ChatFormat::impl {
132133 }
133134
134135 auto [jchat, size] = ac2jsonChatMessages (history);
135- auto fmtHistory = applyJinja (jchat);
136+ auto fmtHistory = applyJinja (jchat, addAssistantPrompt );
136137
137138 jchat.push_back ({{" role" , msg.role }, {" content" , msg.text }});
138- auto fmtNew = applyJinja (jchat);
139+ auto fmtNew = applyJinja (jchat, addAssistantPrompt );
139140
140141 return fmtNew.substr (fmtHistory.size ());
141142 }
@@ -156,19 +157,22 @@ class JinjaImpl final : public ChatFormat::impl {
156157 return {messages, size};
157158 }
158159
159- std::string applyJinja (acnl::json jChat) const {
160+ std::string applyJinja (acnl::json jChat, bool addAssistantPrompt ) const {
160161 auto startsWith = [](const std::string& str, const std::string& prefix) {
161162 return str.rfind (prefix, 0 ) == 0 ;
162163 };
163164
164165 minja::chat_template_inputs tmpl_inputs;
165166 tmpl_inputs.messages = jChat;
167+ tmpl_inputs.add_generation_prompt = addAssistantPrompt;
168+ tmpl_inputs.extra_context = {
169+ {" assistant_role" , m_assistantRole}
170+ };
166171
167- minja::chat_template_options tmpl_opts;
168172 // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
169173 // instead of using `chat_template_options.use_bos_token = false`, since these tokens
170174 // may be needed inside the template / between messages too.
171- auto result = m_minjaTemplate->apply (tmpl_inputs, tmpl_opts );
175+ auto result = m_minjaTemplate->apply (tmpl_inputs);
172176 if (startsWith (result, m_minjaTemplate->bos_token ())) {
173177 result = result.substr (m_minjaTemplate->bos_token ().size ());
174178 }
@@ -180,6 +184,7 @@ class JinjaImpl final : public ChatFormat::impl {
180184
181185 std::unique_ptr<minja::chat_template> m_minjaTemplate;
182186 std::string m_templateStr;
187+ std::string m_assistantRole;
183188};
184189
185190
0 commit comments