 #include <string>
 #include <vector>
+#include <memory> // for std::unique_ptr / std::make_unique used below
 
+// Add a message to `messages` and store its content in `owned_content`.
+// Note: `role` is stored as-is, so it must outlive `messages`
+// (string literals such as "user" and "assistant" are fine).
+static void add_message(const char * role, const std::string & text,
+                        std::vector<llama_chat_message> & messages,
+                        std::vector<std::unique_ptr<char[]>> & owned_content) {
+    auto content = std::make_unique<char[]>(text.size() + 1);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role, content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Apply the chat template and resize `formatted` if needed
+static int apply_chat_template(llama_model * model,
+                               const std::vector<llama_chat_message> & messages,
+                               std::vector<char> & formatted, bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    }
+
+    return result;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
@@ -66,6 +93,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // Load the model from model_path (prints model metadata while loading)
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -90,9 +118,7 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
     // helper function to evaluate a prompt and generate a response
-    auto generate = [&](const std::string & prompt) {
-        std::string response;
-
+    auto generate = [&](const std::string & prompt, std::string & response) {
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
         std::vector<llama_token> prompt_tokens(n_prompt_tokens);
@@ -110,7 +136,7 @@ int main(int argc, char ** argv) {
             if (n_ctx_used + batch.n_tokens > n_ctx) {
                 printf("\033[0m\n");
                 fprintf(stderr, "context size exceeded\n");
-                exit(0);
+                return 1;
             }
 
             if (llama_decode(ctx, batch)) {
@@ -140,55 +166,52 @@ int main(int argc, char ** argv) {
             batch = llama_batch_get_one(&new_token_id, 1);
         }
 
-        return response;
+        return 0;
     };
 
     std::vector<llama_chat_message> messages;
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
         // get user input
         printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-
         if (user.empty()) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int) formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        // Add user input to messages
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        // remove previous messages to obtain the prompt to generate the
+        // response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
 
         // generate a response
         printf("\033[33m");
-        std::string response = generate(prompt);
+        std::string response;
+        if (generate(prompt, response)) {
+            return 1;
+        }
+
         printf("\n\033[0m");
 
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        // Add response to messages
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
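
For reference (not part of the commit): the core of the change is that each `llama_chat_message` holds only non-owning `const char *` pointers, while a parallel `std::vector<std::unique_ptr<char[]>>` owns the copied message text, replacing the old `strdup()`/`free()` pairs. The standalone sketch below demonstrates that ownership pattern in isolation; `chat_message` is a hypothetical stand-in with the same two-pointer layout as `llama_chat_message`, so the snippet builds and runs without llama.cpp.

```cpp
// Minimal sketch of the ownership pattern used above, with a stand-in
// struct instead of llama_chat_message so it compiles without llama.cpp.
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

// Stand-in for llama_chat_message: two non-owning C-string pointers.
struct chat_message {
    const char * role;
    const char * content;
};

// Mirror of add_message(): copy the text into heap storage owned by
// `owned_content` and store only raw pointers in the message struct.
static void add_message(const char * role, const std::string & text,
                        std::vector<chat_message> & messages,
                        std::vector<std::unique_ptr<char[]>> & owned_content) {
    auto content = std::make_unique<char[]>(text.size() + 1);
    std::strcpy(content.get(), text.c_str());
    messages.push_back({role, content.get()});
    owned_content.push_back(std::move(content));
}

int main() {
    std::vector<chat_message> messages;
    std::vector<std::unique_ptr<char[]>> owned_content;

    add_message("user", "hello", messages, owned_content);
    add_message("assistant", "hi there", messages, owned_content);

    // The content pointers stay valid as long as `owned_content` is alive;
    // both vectors clean up automatically when they go out of scope, so no
    // manual free() loop is needed as in the old code.
    for (const auto & msg : messages) {
        printf("%s: %s\n", msg.role, msg.content);
    }
    return 0;
}
```

Keeping the owning and non-owning containers side by side also keeps the `llama_chat_message` array contiguous, so `messages.data()` can still be passed straight to `llama_chat_apply_template()`.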