
Commit c3cbf90

Use smart pointers in simple-chat
Avoid manual memory cleanups. Fewer memory leaks in the code now. Avoid printing multiple dots. Split code into smaller functions. Use C-style IO, rather than a mix of C++ streams and C-style IO. No exception handling.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 1842922 commit c3cbf90
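
The central change described above is wrapping the C-style llama.cpp handles in std::unique_ptr with custom deleters so cleanup runs automatically on every return path. As a rough, standalone illustration of that idiom (not taken from this commit, and using plain C stdio rather than the llama.cpp API; the file name is only a placeholder):

```cpp
// Minimal sketch of the smart-pointer idiom the commit applies to llama.cpp
// handles: pair a C-style resource with its free function in a std::unique_ptr.
#include <cstdio>
#include <memory>

int main() {
    auto file = std::unique_ptr<FILE, decltype(&fclose)>(fopen("notes.txt", "r"), fclose);
    if (!file) {
        return 1; // fopen failed; the unique_ptr holds nullptr, so no deleter runs
    }

    // ... use file.get() with the usual C API ...
    return 0; // fclose(file.get()) is invoked automatically here
}
```

The diff below applies the same idea with llama_free_model, llama_free, and llama_sampler_free as the deleters.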

File tree

5 files changed: +351 -0 lines changed

Makefile

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,7 @@ BUILD_TARGETS = \
 	llama-server \
 	llama-simple \
 	llama-simple-chat \
+	llama-ramalama-core \
 	llama-speculative \
 	llama-tokenize \
 	llama-vdot \
@@ -1382,6 +1383,11 @@ llama-infill: examples/infill/infill.cpp \
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
+llama-ramalama-core: examples/ramalama-core/ramalama-core.cpp \
+	$(OBJ_ALL)
+	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 llama-simple: examples/simple/simple.cpp \
 	$(OBJ_ALL)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ else()
         add_subdirectory(sycl)
     endif()
     add_subdirectory(save-load-state)
+    add_subdirectory(ramalama-core)
     add_subdirectory(simple)
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)

examples/ramalama-core/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+set(TARGET llama-ramalama-core)
+add_executable(${TARGET} ramalama-core.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/ramalama-core/README.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# llama.cpp/example/ramalama-core
+
+The purpose of this example is to demonstrate a minimal usage of llama.cpp for `ramalama run` in the RamaLama project. Others may also find it useful for purposes outside of RamaLama.
+
+```bash
+./llama-ramalama-core -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
+...

examples/ramalama-core/ramalama-core.cpp

Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
+#include <climits>
+#include <cstdio>
+#include <cstdlib> // std::strtol
+#include <cstring> // std::strcpy, strcmp
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "llama.h"
+
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string & role, const std::string & text, std::vector<llama_chat_message> & messages,
+                        std::vector<std::unique_ptr<char[]>> & owned_content) {
+    auto content = std::unique_ptr<char[]>(new char[text.size() + 1]);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model * model, const std::vector<llama_chat_message> & messages,
+                               std::vector<char> & formatted, const bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model * model, const std::string & prompt,
+                           std::vector<llama_token> & prompt_tokens) {
+    const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
+static int check_context_size(const llama_context * ctx, const llama_batch & batch) {
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model * model, const llama_token token_id, std::string & piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
+static int generate(const llama_model * model, llama_sampler * smpl, llama_context * ctx, const std::string & prompt,
+                    std::string & response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        // stop if this batch would not fit in the context window
+        if (check_context_size(ctx, batch)) {
+            return 1;
+        }
+
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+        }
+
+        // sample the next token, check is it an end of generation?
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
+static void print_usage(int, const char ** argv) {
+    printf("\nexample usage:\n");
+    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
+    printf("\n");
+}
+
+static int parse_int_arg(const char * arg, int & value) {
+    char * end;
+    long val = std::strtol(arg, &end, 10);
+    if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
+        value = static_cast<int>(val);
+        return 0;
+    }
+
+    return 1;
+}
+
+static int handle_model_path(const int argc, const char ** argv, int & i, std::string & model_path) {
+    if (i + 1 < argc) {
+        model_path = argv[++i];
+        return 0;
+    }
+
+    print_usage(argc, argv);
+    return 1;
+}
+
+static int handle_n_ctx(const int argc, const char ** argv, int & i, int & n_ctx) {
+    if (i + 1 < argc) {
+        // parse_int_arg() returns 0 on success
+        if (parse_int_arg(argv[++i], n_ctx) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int handle_ngl(const int argc, const char ** argv, int & i, int & ngl) {
+    if (i + 1 < argc) {
+        // parse_int_arg() returns 0 on success
+        if (parse_int_arg(argv[++i], ngl) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int parse_arguments(const int argc, const char ** argv, std::string & model_path, int & n_ctx, int & ngl) {
+    for (int i = 1; i < argc; ++i) {
+        if (strcmp(argv[i], "-m") == 0) {
+            if (handle_model_path(argc, argv, i, model_path)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-c") == 0) {
+            if (handle_n_ctx(argc, argv, i, n_ctx)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-ngl") == 0) {
+            if (handle_ngl(argc, argv, i, ngl)) {
+                return 1;
+            }
+        } else {
+            print_usage(argc, argv);
+            return 1;
+        }
+    }
+
+    if (model_path.empty()) {
+        print_usage(argc, argv);
+        return 1;
+    }
+
+    return 0;
+}
+
+static int read_user_input(std::string & user) {
+    std::getline(std::cin, user);
+    return user.empty(); // Indicate an error or empty input
+}
+
+// Function to generate a response based on the prompt
+static int generate_response(llama_model * model, llama_sampler * sampler, llama_context * context,
+                             const std::string & prompt, std::string & response) {
+    // Set response color
+    printf("\033[33m");
+    if (generate(model, sampler, context, prompt, response)) {
+        fprintf(stderr, "failed to generate response\n");
+        return 1;
+    }
+
+    // End response with color reset and newline
+    printf("\n\033[0m");
+    return 0;
+}
+
+// The main chat loop where user inputs are processed and responses generated.
+static int chat_loop(llama_model * model, llama_sampler * sampler, llama_context * context,
+                     std::vector<llama_chat_message> & messages) {
+    std::vector<std::unique_ptr<char[]>> owned_content;
+    std::vector<char> formatted(llama_n_ctx(context));
+    int prev_len = 0;
+
+    while (true) {
+        // Print prompt for user input
+        printf("\033[32m> \033[0m");
+        std::string user;
+        if (read_user_input(user)) {
+            break;
+        }
+
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+
+        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        std::string response;
+        if (generate_response(model, sampler, context, prompt, response)) {
+            return 1;
+        }
+
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+static void log_callback(const enum ggml_log_level level, const char * text, void *) {
+    if (level == GGML_LOG_LEVEL_ERROR) {
+        fprintf(stderr, "%s", text);
+    }
+}
+
+// Initializes the model and returns a unique pointer to it.
+static std::unique_ptr<llama_model, decltype(&llama_free_model)> initialize_model(const std::string & model_path,
+                                                                                  int ngl) {
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
+
+    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
+        llama_load_model_from_file(model_path.c_str(), model_params), llama_free_model);
+    if (!model) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+    }
+
+    return model;
+}
+
+// Initializes the context with the specified parameters.
+static std::unique_ptr<llama_context, decltype(&llama_free)> initialize_context(llama_model * model, int n_ctx) {
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = n_ctx;
+    ctx_params.n_batch = n_ctx;
+
+    auto context = std::unique_ptr<llama_context, decltype(&llama_free)>(
+        llama_new_context_with_model(model, ctx_params), llama_free);
+    if (!context) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+    }
+
+    return context;
+}
+
+// Initializes and configures the sampler.
+static std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)> initialize_sampler() {
+    auto sampler = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>(
+        llama_sampler_chain_init(llama_sampler_chain_default_params()), llama_sampler_free);
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+    return sampler;
+}
+
+int main(int argc, const char ** argv) {
+    std::string model_path;
+    int ngl = 99;
+    int n_ctx = 2048;
+    if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) {
+        return 1;
+    }
+
+    llama_log_set(log_callback, nullptr);
+    auto model = initialize_model(model_path, ngl);
+    if (!model) {
+        return 1;
+    }
+
+    auto context = initialize_context(model.get(), n_ctx);
+    if (!context) {
+        return 1;
+    }
+
+    auto sampler = initialize_sampler();
+    std::vector<llama_chat_message> messages;
+    if (chat_loop(model.get(), sampler.get(), context.get(), messages)) {
+        return 1;
+    }
+
+    return 0;
+}

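One detail worth calling out from the file above: apply_chat_template uses the common two-pass pattern for size-reporting C APIs, calling once with the current buffer and retrying with a larger one when the reported size does not fit. A standalone sketch of the same pattern, built around a hypothetical size-reporting helper rather than llama_chat_apply_template:

```cpp
// Two-pass "query the required size, grow, retry" pattern, as used by
// apply_chat_template above. format_greeting() stands in for any C-style
// function that reports how many bytes it needs; it is not a llama.cpp API.
#include <cstdio>
#include <vector>

static int format_greeting(const char * name, char * buf, std::size_t buf_size) {
    // std::snprintf returns the length the full output would need,
    // excluding the terminating null byte.
    return std::snprintf(buf, buf_size, "Hello, %s!", name);
}

int main() {
    std::vector<char> formatted(8); // deliberately too small for the first pass
    int n = format_greeting("RamaLama", formatted.data(), formatted.size());
    if (n >= static_cast<int>(formatted.size())) {
        formatted.resize(n + 1); // grow to the reported size (plus the null byte)
        n = format_greeting("RamaLama", formatted.data(), formatted.size());
    }

    std::printf("%.*s\n", n, formatted.data());
    return 0;
}
```

A single retry is enough because the second pass uses the exact size reported by the first one.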