|
5 | 5 | #include <string> |
6 | 6 | #include <vector> |
| 7 | +#include <memory>  // needed for std::unique_ptr / std::make_unique below |
7 | 7 |
|
| 8 | +// Add a message to `messages`; its text is copied into `owned_content`, which owns it (llama_chat_message only stores raw pointers) |
| 9 | +static void add_message(const char * role, const std::string &text, |
| 10 | + std::vector<llama_chat_message> &messages, |
| 11 | + std::vector<std::unique_ptr<char[]>> &owned_content) { |
| 12 | + auto content = std::make_unique<char[]>(text.size() + 1); |
| 13 | + std::strcpy(content.get(), text.c_str()); |
| 14 | + messages.push_back({role, content.get()}); // role points to a string literal; the content copy is owned by owned_content |
| 15 | + owned_content.push_back(std::move(content)); |
| 16 | +} |
| 17 | + |
| 18 | +// Apply the chat template to `messages`, growing `formatted` if it is too small; returns the formatted length, or a negative value on error |
| 19 | +static int apply_chat_template(llama_model *model, |
| 20 | + const std::vector<llama_chat_message> &messages, |
| 21 | + std::vector<char> &formatted, bool append) { |
| 22 | + int result = llama_chat_apply_template(model, nullptr, messages.data(), |
| 23 | + messages.size(), append, |
| 24 | + formatted.data(), formatted.size()); |
| 25 | + if (result > static_cast<int>(formatted.size())) { |
| 26 | + formatted.resize(result); |
| 27 | + result = llama_chat_apply_template(model, nullptr, messages.data(), |
| 28 | + messages.size(), append, |
| 29 | + formatted.data(), formatted.size()); |
| 30 | + } |
| 31 | + |
| 32 | + return result; |
| 33 | +} |
| 34 | + |
8 | 35 | static void print_usage(int, char ** argv) { |
9 | 36 | printf("\nexample usage:\n"); |
10 | 37 | printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]); |
@@ -66,6 +93,7 @@ int main(int argc, char ** argv) { |
66 | 93 | llama_model_params model_params = llama_model_default_params(); |
67 | 94 | model_params.n_gpu_layers = ngl; |
68 | 95 |
|
| 96 | + // load the model; this prints the model metadata to stderr while loading |
69 | 97 | llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params); |
70 | 98 | if (!model) { |
71 | 99 | fprintf(stderr , "%s: error: unable to load model\n" , __func__); |
@@ -144,51 +172,43 @@ int main(int argc, char ** argv) { |
144 | 172 | }; |
145 | 173 |
|
146 | 174 | std::vector<llama_chat_message> messages; |
| 175 | + std::vector<std::unique_ptr<char[]>> owned_content; |
147 | 176 | std::vector<char> formatted(llama_n_ctx(ctx)); |
148 | 177 | int prev_len = 0; |
149 | 178 | while (true) { |
150 | | - // get user input |
151 | | - printf("\033[32m> \033[0m"); |
152 | | - std::string user; |
153 | | - std::getline(std::cin, user); |
154 | | - |
155 | | - if (user.empty()) { |
156 | | - break; |
157 | | - } |
158 | | - |
159 | | - // add the user input to the message list and format it |
160 | | - messages.push_back({"user", strdup(user.c_str())}); |
161 | | - int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); |
162 | | - if (new_len > (int)formatted.size()) { |
163 | | - formatted.resize(new_len); |
164 | | - new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); |
165 | | - } |
166 | | - if (new_len < 0) { |
167 | | - fprintf(stderr, "failed to apply the chat template\n"); |
168 | | - return 1; |
169 | | - } |
| 179 | + // get user input |
| 180 | + printf("\033[32m> \033[0m"); |
| 181 | + std::string user; |
| 182 | + std::getline(std::cin, user); |
| 183 | + if (user.empty()) { |
| 184 | + break; |
| 185 | + } |
| 186 | + |
| 187 | + // Add user input to messages |
| 188 | + add_message("user", user, messages, owned_content); |
| 189 | + int new_len = apply_chat_template(model, messages, formatted, true); |
| 190 | + if (new_len < 0) { |
| 191 | + fprintf(stderr, "failed to apply the chat template\n"); |
| 192 | + return 1; |
| 193 | + } |
170 | 194 |
|
171 | | - // remove previous messages to obtain the prompt to generate the response |
172 | | - std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len); |
| 195 | + // remove previous messages to obtain the prompt to generate the response |
| 196 | + std::string prompt(formatted.begin() + prev_len, |
| 197 | + formatted.begin() + new_len); |
173 | 198 |
|
174 | | - // generate a response |
175 | | - printf("\033[33m"); |
176 | | - std::string response = generate(prompt); |
177 | | - printf("\n\033[0m"); |
| 199 | + // generate a response |
| 200 | + printf("\033[33m"); |
| 201 | + std::string response = generate(prompt); |
| 202 | + printf("\n\033[0m"); |
178 | 203 |
|
179 | | - // add the response to the messages |
180 | | - messages.push_back({"assistant", strdup(response.c_str())}); |
181 | | - prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0); |
182 | | - if (prev_len < 0) { |
183 | | - fprintf(stderr, "failed to apply the chat template\n"); |
184 | | - return 1; |
185 | | - } |
| 204 | + // Add the response to the messages so it is part of the next turn's context |
| 205 | + add_message("assistant", response, messages, owned_content); |
| 206 | + prev_len = apply_chat_template(model, messages, formatted, false); |
| 206 | + if (prev_len < 0) { |
| 207 | + fprintf(stderr, "failed to apply the chat template\n"); |
| 208 | + return 1; |
| 209 | + } |
186 | 210 | } |
187 | 211 |
|
188 | | - // free resources |
189 | | - for (auto & msg : messages) { |
190 | | - free(const_cast<char *>(msg.content)); |
191 | | - } |
192 | 212 | llama_sampler_free(smpl); |
193 | 213 | llama_free(ctx); |
194 | 214 | llama_free_model(model); |
|
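For context, the ownership pattern introduced above can be shown in isolation. The sketch below is hypothetical: it uses a stand-in `chat_message` struct instead of `llama_chat_message` so it builds without llama.h. The point is that the message structs hold only raw, non-owning pointers, while a parallel `std::vector<std::unique_ptr<char[]>>` owns the copied strings, which is what allows the old `strdup()`/`free()` cleanup loop to be dropped.

```cpp
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

// Stand-in for llama_chat_message: it only stores raw, non-owning pointers.
struct chat_message {
    const char * role;
    const char * content;
};

// Copy `text` into heap storage tracked by `owned_content`, then push a
// message whose content pointer refers to that storage.
static void add_message(const char * role, const std::string & text,
                        std::vector<chat_message> & messages,
                        std::vector<std::unique_ptr<char[]>> & owned_content) {
    auto content = std::make_unique<char[]>(text.size() + 1);
    std::strcpy(content.get(), text.c_str());
    messages.push_back({role, content.get()});
    owned_content.push_back(std::move(content));
}

int main() {
    std::vector<chat_message> messages;
    std::vector<std::unique_ptr<char[]>> owned_content;

    add_message("user", "hello", messages, owned_content);
    add_message("assistant", "hi there", messages, owned_content);

    for (const auto & m : messages) {
        printf("%s: %s\n", m.role, m.content);
    }
    // No manual free loop needed: owned_content releases every copy here.
    return 0;
}
```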