Commit 3037285

Use smart pointers in simple-chat
Avoid manual memory cleanup. Fewer memory leaks in the code now.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 2a82891 commit 3037285
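Note: the pattern this commit switches to is to copy each chat string into a heap buffer owned by a std::unique_ptr<char[]>, hand only the raw pointer to the C-style llama_chat_message struct, and keep the unique_ptr in a separate vector so the buffer outlives the struct. A minimal, self-contained sketch of that idea (the chat_message struct below is a stand-in for the real llama.cpp type, not the actual API):

#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

// Stand-in for llama_chat_message: a C-style struct that only borrows pointers.
struct chat_message {
    const char * role;
    const char * content;
};

int main() {
    std::vector<chat_message> messages;
    std::vector<std::unique_ptr<char[]>> owned_content;

    std::string user = "hello";

    // Copy the string into a heap buffer owned by a unique_ptr ...
    auto buf = std::make_unique<char[]>(user.size() + 1);
    std::strcpy(buf.get(), user.c_str());

    // ... give the struct only the raw pointer ...
    messages.push_back({"user", buf.get()});

    // ... and park the unique_ptr so the buffer stays alive.
    owned_content.push_back(std::move(buf));

    std::printf("%s: %s\n", messages[0].role, messages[0].content);

    // owned_content is destroyed on return and frees every buffer,
    // which is what replaces the manual free() loop removed below.
    return 0;
}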

2 files changed, +49 -41 lines

examples/simple-chat/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@ set(TARGET llama-simple-chat)
 add_executable(${TARGET} simple-chat.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_11)
+target_compile_features(${TARGET} PRIVATE cxx_std_14)
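The bump from cxx_std_11 to cxx_std_14 is needed because std::make_unique, used in the new simple-chat.cpp code, was only introduced in C++14. For comparison (illustrative snippet, not part of the commit):

#include <memory>
#include <string>

int main() {
    std::string user = "hi";

    // C++14: what simple-chat.cpp now uses.
    auto a = std::make_unique<char[]>(user.size() + 1);

    // C++11 spelling of the same allocation, shown only for comparison.
    std::unique_ptr<char[]> b(new char[user.size() + 1]);

    (void)a; (void)b;
    return 0;
}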

examples/simple-chat/simple-chat.cpp

Lines changed: 48 additions & 40 deletions
@@ -66,6 +66,7 @@ int main(int argc, char ** argv) {
     llama_model_params model_params = llama_model_default_params();
     model_params.n_gpu_layers = ngl;
 
+    // This prints ........
     llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
     if (!model) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@@ -144,51 +145,58 @@
     };
 
     std::vector<llama_chat_message> messages;
+    std::vector<std::unique_ptr<char[]>> owned_content;
     std::vector<char> formatted(llama_n_ctx(ctx));
     int prev_len = 0;
     while (true) {
-        // get user input
-        printf("\033[32m> \033[0m");
-        std::string user;
-        std::getline(std::cin, user);
-
-        if (user.empty()) {
-            break;
-        }
-
-        // add the user input to the message list and format it
-        messages.push_back({"user", strdup(user.c_str())});
-        int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        if (new_len > (int)formatted.size()) {
-            formatted.resize(new_len);
-            new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
-        }
-        if (new_len < 0) {
-            fprintf(stderr, "failed to apply the chat template\n");
-            return 1;
-        }
-
-        // remove previous messages to obtain the prompt to generate the response
-        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
-
-        // generate a response
-        printf("\033[33m");
-        std::string response = generate(prompt);
-        printf("\n\033[0m");
-
-        // add the response to the messages
-        messages.push_back({"assistant", strdup(response.c_str())});
-        prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0);
-        if (prev_len < 0) {
-            fprintf(stderr, "failed to apply the chat template\n");
-            return 1;
-        }
+        // get user input
+        printf("\033[32m> \033[0m");
+        std::string user;
+        std::getline(std::cin, user);
+        if (user.empty()) {
+            break;
+        }
+
+        auto user_content = std::make_unique<char[]>(user.size() + 1);
+        std::strcpy(user_content.get(), user.c_str());
+        messages.push_back({"user", user_content.get()});
+        owned_content.push_back(std::move(user_content)); // manage lifetime
+        int new_len = llama_chat_apply_template(
+            model, nullptr, messages.data(), messages.size(), true,
+            formatted.data(), formatted.size());
+        if (new_len > static_cast<int>(formatted.size())) {
+            formatted.resize(new_len);
+            new_len = llama_chat_apply_template(model, nullptr, messages.data(),
+                                                messages.size(), true,
+                                                formatted.data(), formatted.size());
+        }
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+
+        // remove previous messages to obtain the prompt to generate the response
+        std::string prompt(formatted.begin() + prev_len,
+                           formatted.begin() + new_len);
+
+        // generate a response
+        printf("\033[33m");
+        std::string response = generate(prompt);
+        printf("\n\033[0m");
+
+        // Allocate memory and copy response, storing in `owned_content`
+        auto response_content = std::make_unique<char[]>(response.size() + 1);
+        std::strcpy(response_content.get(), response.c_str());
+        messages.push_back({"assistant", response_content.get()});
+        owned_content.push_back(std::move(response_content));
+        prev_len = llama_chat_apply_template(model, nullptr, messages.data(),
+                                             messages.size(), false, nullptr, 0);
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
     }
 
-    // free resources
-    for (auto & msg : messages) {
-        free(const_cast<char *>(msg.content));
-    }
     llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
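A side note on the choice of std::unique_ptr<char[]> as the owner: the heap buffers never move even when owned_content reallocates, so the const char * pointers stored in messages remain valid for the whole loop. A small illustration of that property (illustrative code, not part of the commit):

#include <cassert>
#include <cstring>
#include <memory>
#include <vector>

int main() {
    std::vector<std::unique_ptr<char[]>> owned;

    auto first = std::make_unique<char[]>(6);
    std::strcpy(first.get(), "hello");
    const char * borrowed = first.get();   // what a C struct would hold on to
    owned.push_back(std::move(first));

    // Grow the vector so it reallocates and moves its unique_ptrs around.
    for (int i = 0; i < 1000; ++i) {
        owned.push_back(std::make_unique<char[]>(1));
    }

    // The character buffer itself never moved, so the borrowed pointer is intact.
    assert(borrowed == owned[0].get());
    return 0;
}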
