@@ -12,9 +12,7 @@ static void print_usage(int, char ** argv) {
 }
 
 int main(int argc, char ** argv) {
-    // path to the model gguf file
     std::string model_path;
-    // number of layers to offload to the GPU
     int ngl = 99;
     int n_ctx = 2048;
 
@@ -91,13 +89,13 @@ int main(int argc, char ** argv) {
     llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
     llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
 
-    // generation helper
+    // helper function to evaluate a prompt and generate a response
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
         // tokenize the prompt
-        const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
-        std::vector<llama_token> prompt_tokens(n_prompt);
+        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+        std::vector<llama_token> prompt_tokens(n_prompt_tokens);
         if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
             GGML_ABORT("failed to tokenize the prompt\n");
         }
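
The negated first call above is llama.cpp's usual size-query idiom: with a NULL output buffer, llama_tokenize returns the negative of the number of tokens required, so the renamed n_prompt_tokens is exactly the size needed for the second, filling call. A minimal sketch of the same two-pass pattern factored into a helper (tokenize_prompt is a hypothetical name, not part of this commit; the two trailing booleans control special-token handling and are forwarded unchanged from the hunk above):

    // hypothetical helper showing the two-pass tokenization idiom used above
    static std::vector<llama_token> tokenize_prompt(const llama_model * model, const std::string & prompt) {
        // first pass: NULL buffer -> llama_tokenize returns -(number of tokens required)
        const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
        std::vector<llama_token> tokens(n_prompt_tokens);
        // second pass fills the buffer; a negative return means tokenization failed
        if (llama_tokenize(model, prompt.c_str(), prompt.size(), tokens.data(), tokens.size(), true, true) < 0) {
            GGML_ABORT("failed to tokenize the prompt\n");
        }
        return tokens;
    }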
@@ -106,7 +104,7 @@ int main(int argc, char ** argv) {
         llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
         llama_token new_token_id;
         while (true) {
-            // check if we have enough context space to evaluate this batch
+            // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
             int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
             if (n_ctx_used + batch.n_tokens > n_ctx) {
@@ -116,7 +114,7 @@ int main(int argc, char ** argv) {
             }
 
             if (llama_decode(ctx, batch)) {
-                GGML_ABORT("failed to eval\n");
+                GGML_ABORT("failed to decode\n");
             }
 
             // sample the next token
@@ -127,16 +125,16 @@ int main(int argc, char ** argv) {
                 break;
             }
 
-            // add the token to the response
-            char buf[128];
+            // convert the token to a string, print it and add it to the response
+            char buf[256];
             int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
             if (n < 0) {
                 GGML_ABORT("failed to convert token to piece\n");
             }
             std::string piece(buf, n);
-            response += piece;
             printf("%s", piece.c_str());
             fflush(stdout);
+            response += piece;
 
             // prepare the next batch with the sampled token
             batch = llama_batch_get_one(&new_token_id, 1);
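
For detokenization the commit simply doubles the scratch buffer (128 to 256 bytes) and keeps aborting on a negative return. The same call can be read as a small standalone helper; a hedged sketch (token_to_piece_str is a hypothetical name; the 0 and trailing true arguments are forwarded unchanged from the call above):

    // hypothetical wrapper around the llama_token_to_piece call shown above
    static std::string token_to_piece_str(const llama_model * model, llama_token token) {
        char buf[256];
        const int n = llama_token_to_piece(model, token, buf, sizeof(buf), 0, true);
        if (n < 0) {
            GGML_ABORT("failed to convert token to piece\n");
        }
        return std::string(buf, n);
    }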
@@ -146,34 +144,51 @@ int main(int argc, char ** argv) {
     };
 
     std::vector<llama_chat_message> messages;
-    std::vector<char> formatted(2048);
+    std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
+        // get user input
+        printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-        messages.push_back({"user", strdup(user.c_str())});
 
-        // format the messages
+        if (user.empty()) {
+            break;
+        }
+
+        // add the user input to the message list and format it
+        messages.push_back({"user", strdup(user.c_str())});
         int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         if (new_len > (int)formatted.size()) {
             formatted.resize(new_len);
             new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
         }
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
 
-        // remove previous messages and obtain a prompt
+        // remove previous messages to obtain the prompt to generate the response
         std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
 
         // generate a response
-        printf("\033[31m");
+        printf("\033[33m");
         std::string response = generate(prompt);
         printf("\n\033[0m");
 
         // add the response to the messages
         messages.push_back({"assistant", strdup(response.c_str())});
         prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, formatted.data(), formatted.size());
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
     }
 
-
+    // free resources
+    for (auto & msg : messages) {
+        free(const_cast<char *>(msg.content));
+    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
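
The prev_len / new_len bookkeeping is what keeps the chat loop incremental: llama_chat_apply_template always renders the whole message list (the true argument also requests the tokens that open an assistant message), and only the bytes added since the previous call are handed to generate(); once the reply is pushed, the template is re-applied with false and its length becomes the next prev_len. A self-contained sketch of that bookkeeping with plain strings, where format_transcript is a made-up renderer standing in for llama_chat_apply_template, not the model's real chat template:

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    // hypothetical stand-in for llama_chat_apply_template: renders ALL messages,
    // optionally appending an empty assistant header at the end
    static std::string format_transcript(const std::vector<std::pair<std::string, std::string>> & msgs, bool add_assistant_header) {
        std::string out;
        for (const auto & m : msgs) {
            out += "<" + m.first + ">" + m.second + "</" + m.first + ">\n";
        }
        if (add_assistant_header) {
            out += "<assistant>";
        }
        return out;
    }

    int main() {
        std::vector<std::pair<std::string, std::string>> messages;
        size_t prev_len = 0;

        for (const std::string & user : {std::string("Hi"), std::string("Tell me more")}) {
            messages.push_back({"user", user});
            const std::string full = format_transcript(messages, /*add_assistant_header=*/true);
            // only the newly formatted tail is the prompt for this turn, mirroring
            // std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
            const std::string prompt = full.substr(prev_len);
            printf("prompt for this turn:\n%s\n---\n", prompt.c_str());

            messages.push_back({"assistant", "(model reply)"});
            prev_len = format_transcript(messages, /*add_assistant_header=*/false).size();
        }
        return 0;
    }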