Commit 152158c

Use smart pointers in simple-chat
Avoid manual memory cleanup. Fewer memory leaks in the code now.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 2a82891 commit 152158c
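The gist of the change, as a minimal stand-alone sketch: each message's text is copied into a buffer owned by a std::unique_ptr<char[]>, the C struct only keeps a non-owning pointer into that buffer, and the buffers are released automatically when the owning vector goes out of scope. (chat_message below is an illustrative stand-in for llama_chat_message, not the real type.)

    #include <cstring>
    #include <memory>
    #include <string>
    #include <vector>

    // Illustrative stand-in for a C-style struct that stores non-owning pointers.
    struct chat_message { const char * role; const char * content; };

    int main() {
        std::vector<chat_message>            messages;
        std::vector<std::unique_ptr<char[]>> owned;   // owns every content buffer

        std::string text = "hello";
        auto buf = std::make_unique<char[]>(text.size() + 1);
        std::strcpy(buf.get(), text.c_str());
        messages.push_back({"user", buf.get()});      // non-owning view
        owned.push_back(std::move(buf));              // ownership moved into `owned`

        // No manual free() loop needed: the unique_ptr destructors release the
        // buffers when `owned` is destroyed.
    }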

File tree

2 files changed: +39 −18 lines

examples/simple-chat/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@ set(TARGET llama-simple-chat)
 add_executable(${TARGET} simple-chat.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_features(${TARGET} PRIVATE cxx_std_14)
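The bump from cxx_std_11 to cxx_std_14 is presumably needed because the new add_message helper below uses std::make_unique, which only exists since C++14. A hypothetical one-function illustration (copy_buffer is not part of the example):

    #include <memory>
    #include <string>

    // Hypothetical helper, for illustration only.
    static std::unique_ptr<char[]> copy_buffer(const std::string & text) {
        // The C++11 spelling would be: std::unique_ptr<char[]>(new char[text.size() + 1])
        return std::make_unique<char[]>(text.size() + 1);   // std::make_unique is C++14
    }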

examples/simple-chat/simple-chat.cpp

Lines changed: 38 additions & 17 deletions
@@ -5,6 +5,33 @@
 #include <string>
 #include <vector>
 
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string &role, const std::string &text,
+                        std::vector<llama_chat_message> &messages,
+                        std::vector<std::unique_ptr<char[]>> &owned_content) {
+    auto content = std::make_unique<char[]>(text.size() + 1);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(llama_model *model,
+                               const std::vector<llama_chat_message> &messages,
+                               std::vector<char> &formatted, bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(),
+                                           messages.size(), append,
+                                           formatted.data(), formatted.size());
+    }
+
+    return result;
+}
+
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
     printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
@@ -66,6 +93,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // This prints ........
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@@ -144,51 +172,44 @@ int main(int argc, char ** argv) {
     };
 
     std::vector<llama_chat_message> messages;
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
         // get user input
         printf("\033[32m> \033[0m");
         std::string user;
         std::getline(std::cin, user);
-
         if (user.empty()) {
             break;
         }
 
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
+        // Add user input to messages
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
         if (new_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
 
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        // remove previous messages to obtain the prompt to generate the
+        // response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
 
         // generate a response
         printf("\033[33m");
         std::string response = generate(prompt);
         printf("\n\033[0m");
 
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
+        // Add response to messages
+        prev_len = apply_chat_template(model, messages, formatted, false);
         if (prev_len < 0) {
             fprintf(stderr, "failed to apply the chat template\n");
             return 1;
         }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
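For reference, a condensed sketch of how the two new helpers are meant to be used together, mirroring the loop in the diff above (model and ctx are assumed to be initialized earlier, as in the example; the literal prompt string is only a placeholder):

    std::vector<llama_chat_message>      messages;
    std::vector<std::unique_ptr<char[]>> owned_content;
    std::vector<char>                    formatted(llama_n_ctx(ctx));

    // user turn: copy the text into owned storage, then re-render the template
    add_message("user", "Hello!", messages, owned_content);
    int new_len = apply_chat_template(model, messages, formatted, true);
    if (new_len < 0) {
        fprintf(stderr, "failed to apply the chat template\n");
        return 1;
    }
    std::string prompt(formatted.begin(), formatted.begin() + new_len);
    // ... generate(prompt), then apply the template again with append == false
    // to compute prev_len for the next turn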
