+#include <memory>
 #include <string>
 #include <vector>
 
+// Add a message to `messages` and store its content in `owned_content`
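+// (llama_chat_message only stores raw `const char *` pointers, so the copied
+// text must stay alive for as long as the message is in use)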
+static void add_message(const std::string &role, const std::string &text,
+                        std::vector<llama_chat_message> &messages,
+                        std::vector<std::unique_ptr<char[]>> &owned_content) {
+    auto content = std::make_unique<char[]>(text.size() + 1);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Apply the chat template and resize `formatted` if needed
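+// llama_chat_apply_template returns the total length required for the
+// formatted chat (or a negative value on error), so a second pass is needed
+// when the first result did not fit into `formatted`.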
+static int apply_chat_template(llama_model *model,
+                               const std::vector<llama_chat_message> &messages,
+                               std::vector<char> &formatted, bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    }
+
+    return result;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n    %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
@@ -66,6 +93,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // load the model from file (this prints model information while loading)
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@@ -144,51 +172,44 @@ int main(int argc, char ** argv) {
     };
 
     std::vector<llama_chat_message> messages;
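+    // owns the strings referenced by `messages`; they are freed automatically
+    // when this vector goes out of scope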
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
         // get user input
         printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-
         if (user.empty()) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        // Add user input to messages
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        // remove previous messages to obtain the prompt to generate the
+        // response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
 
         // generate a response
         printf("\033[33m");
         std::string response = generate(prompt);
         printf("\n\033[0m");
 
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        // Add response to messages
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
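+    // (message contents are now released automatically by `owned_content`)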
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);