Skip to content

Commit e6862cf

Browse files
committed
llama : add simple-chat example
1 parent 85679d3 commit e6862cf

File tree

6 files changed

+218
-4
lines changed

6 files changed

+218
-4
lines changed

Makefile

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,11 @@ llama-simple: examples/simple/simple.cpp \
12871287
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
12881288
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
12891289

1290+
llama-simple-chat: examples/simple-chat/simple-chat.cpp \
1291+
$(OBJ_ALL)
1292+
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
1293+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
1294+
12901295
llama-tokenize: examples/tokenize/tokenize.cpp \
12911296
$(OBJ_ALL)
12921297
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ else()
4949
endif()
5050
add_subdirectory(save-load-state)
5151
add_subdirectory(simple)
52+
add_subdirectory(simple-chat)
5253
add_subdirectory(speculative)
5354
add_subdirectory(tokenize)
5455
endif()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(TARGET llama-simple-chat)
2+
add_executable(${TARGET} simple-chat.cpp)
3+
install(TARGETS ${TARGET} RUNTIME)
4+
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
5+
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/simple-chat/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# llama.cpp/example/simple-chat
2+
3+
The purpose of this example is to demonstrate a minimal usage of llama.cpp for building a simple interactive chat program that formats the conversation with the model's chat template.
4+
5+
```bash
6+
./llama-simple -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is"
7+
8+
...
9+
10+
main: n_len = 32, n_ctx = 2048, n_parallel = 1, n_kv_req = 32
11+
12+
Hello my name is Shawn and I'm a 20 year old male from the United States. I'm a 20 year old
13+
14+
main: decoded 27 tokens in 2.31 s, speed: 11.68 t/s
15+
16+
llama_print_timings: load time = 579.15 ms
17+
llama_print_timings: sample time = 0.72 ms / 28 runs ( 0.03 ms per token, 38888.89 tokens per second)
18+
llama_print_timings: prompt eval time = 655.63 ms / 10 tokens ( 65.56 ms per token, 15.25 tokens per second)
19+
llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second)
20+
llama_print_timings: total time = 2891.13 ms
21+
```
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
#include "llama.h"

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
7+
8+
// Print a short usage banner to stdout. The first parameter (argc) is unused;
// argv[0] supplies the program name shown in the example command line.
static void print_usage(int /* argc */, char ** argv) {
    const char * program = argv[0];
    printf(
        "\nexample usage:\n"
        "\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n"
        "\n",
        program);
}
13+
14+
int main(int argc, char ** argv) {
15+
// path to the model gguf file
16+
std::string model_path;
17+
// number of layers to offload to the GPU
18+
int ngl = 99;
19+
int n_ctx = 2048;
20+
21+
// parse command line arguments
22+
for (int i = 1; i < argc; i++) {
23+
try {
24+
if (strcmp(argv[i], "-m") == 0) {
25+
if (i + 1 < argc) {
26+
model_path = argv[++i];
27+
} else {
28+
print_usage(argc, argv);
29+
return 1;
30+
}
31+
} else if (strcmp(argv[i], "-c") == 0) {
32+
if (i + 1 < argc) {
33+
n_ctx = std::stoi(argv[++i]);
34+
} else {
35+
print_usage(argc, argv);
36+
return 1;
37+
}
38+
} else if (strcmp(argv[i], "-ngl") == 0) {
39+
if (i + 1 < argc) {
40+
ngl = std::stoi(argv[++i]);
41+
} else {
42+
print_usage(argc, argv);
43+
return 1;
44+
}
45+
} else {
46+
print_usage(argc, argv);
47+
return 1;
48+
}
49+
} catch (std::exception & e) {
50+
fprintf(stderr, "error: %s\n", e.what());
51+
print_usage(argc, argv);
52+
return 1;
53+
}
54+
}
55+
if (model_path.empty()) {
56+
print_usage(argc, argv);
57+
return 1;
58+
}
59+
60+
// only print errors
61+
llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) {
62+
if (level >= GGML_LOG_LEVEL_ERROR) {
63+
fprintf(stderr, "%s", text);
64+
}
65+
}, nullptr);
66+
67+
// initialize the model
68+
llama_model_params model_params = llama_model_default_params();
69+
model_params.n_gpu_layers = ngl;
70+
71+
llama_model * model = llama_load_model_from_file(model_path.c_str(), model_params);
72+
if (!model) {
73+
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
74+
return 1;
75+
}
76+
77+
// initialize the context
78+
llama_context_params ctx_params = llama_context_default_params();
79+
ctx_params.n_ctx = n_ctx;
80+
ctx_params.n_batch = n_ctx;
81+
82+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
83+
if (!ctx) {
84+
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
85+
return 1;
86+
}
87+
88+
// initialize the sampler
89+
llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
90+
llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1));
91+
llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
92+
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
93+
94+
// generation helper
95+
auto generate = [&](const std::string & prompt) {
96+
std::string response;
97+
98+
// tokenize the prompt
99+
const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
100+
std::vector<llama_token> prompt_tokens(n_prompt);
101+
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
102+
GGML_ABORT("failed to tokenize the prompt\n");
103+
}
104+
105+
// prepare a batch for the prompt
106+
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
107+
llama_token new_token_id;
108+
while (true) {
109+
// check if we have enough context space to evaluate this batch
110+
int n_ctx = llama_n_ctx(ctx);
111+
int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
112+
if (n_ctx_used + batch.n_tokens > n_ctx) {
113+
printf("\033[0m\n");
114+
fprintf(stderr, "context size exceeded\n");
115+
exit(0);
116+
}
117+
118+
if (llama_decode(ctx, batch)) {
119+
GGML_ABORT("failed to eval\n");
120+
}
121+
122+
// sample the next token
123+
new_token_id = llama_sampler_sample(smpl, ctx, -1);
124+
125+
// is it an end of generation?
126+
if (llama_token_is_eog(model, new_token_id)) {
127+
break;
128+
}
129+
130+
// add the token to the response
131+
char buf[128];
132+
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true);
133+
if (n < 0) {
134+
GGML_ABORT("failed to convert token to piece\n");
135+
}
136+
std::string piece(buf, n);
137+
response += piece;
138+
printf("%s", piece.c_str());
139+
fflush(stdout);
140+
141+
// prepare the next batch with the sampled token
142+
batch = llama_batch_get_one(&new_token_id, 1);
143+
}
144+
145+
return response;
146+
};
147+
148+
std::vector<llama_chat_message> messages;
149+
std::vector<char> formatted(2048);
150+
int prev_len = 0;
151+
while (true) {
152+
std::string user;
153+
std::getline(std::cin, user);
154+
messages.push_back({"user", strdup(user.c_str())});
155+
156+
// format the messages
157+
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
158+
if (new_len > (int)formatted.size()) {
159+
formatted.resize(new_len);
160+
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size());
161+
}
162+
163+
// remove previous messages and obtain a prompt
164+
std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
165+
166+
// generate a response
167+
printf("\033[31m");
168+
std::string response = generate(prompt);
169+
printf("\n\033[0m");
170+
171+
// add the response to the messages
172+
messages.push_back({"assistant", strdup(response.c_str())});
173+
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, formatted.data(), formatted.size());
174+
}
175+
176+
177+
llama_sampler_free(smpl);
178+
llama_free(ctx);
179+
llama_free_model(model);
180+
181+
return 0;
182+
}

ggml/include/ggml.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -559,10 +559,10 @@ extern "C" {
559559

560560
enum ggml_log_level {
561561
GGML_LOG_LEVEL_NONE = 0,
562-
GGML_LOG_LEVEL_INFO = 1,
563-
GGML_LOG_LEVEL_WARN = 2,
564-
GGML_LOG_LEVEL_ERROR = 3,
565-
GGML_LOG_LEVEL_DEBUG = 4,
562+
GGML_LOG_LEVEL_DEBUG = 1,
563+
GGML_LOG_LEVEL_INFO = 2,
564+
GGML_LOG_LEVEL_WARN = 3,
565+
GGML_LOG_LEVEL_ERROR = 4,
566566
GGML_LOG_LEVEL_CONT = 5, // continue previous log
567567
};
568568

0 commit comments

Comments
 (0)