Commit 95e0afb

wip: chat cli
1 parent c05e8c9 commit 95e0afb

File tree

3 files changed: +243, -22 lines


examples/main/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 set(TARGET llama-cli)
-add_executable(${TARGET} main.cpp)
+add_executable(${TARGET} main.cpp chat.hpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/main/chat.hpp

Lines changed: 222 additions & 0 deletions

@@ -0,0 +1,222 @@
+#include "arg.h"
+#include "common.h"
+#include "console.h"
+#include "log.h"
+#include "sampling.h"
+#include "llama.h"
+
+#include <fstream>
+
+struct llama_cli_chat {
+    struct llama_context * ctx;
+    const struct llama_model * model;
+    struct common_sampler * smpl;
+    struct common_params params;
+
+    bool interacting = false;
+    std::vector<common_chat_msg> chat_msgs;
+    std::ostringstream pending_input;
+
+    struct llama_batch batch;
+    llama_tokens cache_tokens;
+    int n_past = 0;
+
+    llama_cli_chat(
+            struct common_params & params,
+            struct llama_context * ctx,
+            struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
+        model = llama_get_model(ctx);
+        batch = llama_batch_init(params.n_batch, 0, 1);
+    }
+
+    void decode(llama_tokens & eval_tokens, bool is_generating) {
+        if (is_generating) {
+            GGML_ASSERT(eval_tokens.size() == 1);
+        } else {
+            n_past = common_lcp(cache_tokens, eval_tokens);
+            // in case we do a re-generation, we need to prevent eval_tokens from being empty
+            if ((int) eval_tokens.size() == n_past) {
+                n_past--;
+            }
+            if (n_past > 0) {
+                eval_tokens.erase(eval_tokens.begin(), eval_tokens.begin() + n_past);
+                cache_tokens.erase(cache_tokens.begin() + n_past, cache_tokens.end());
+                LOG_DBG("remove from cache [%d, inf)\n", n_past);
+                LOG_DBG("in cache: %s\n", common_detokenize(ctx, cache_tokens, true).c_str());
+                LOG_DBG("to decode %d tokens\n", (int) eval_tokens.size());
+                llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+            }
+        }
+
+        // decode
+        for (size_t i = 0; i < eval_tokens.size(); i += params.n_batch) {
+            if (interacting) {
+                break;
+            }
+
+            common_batch_clear(batch);
+            for (int j = 0; j < params.n_batch && i + j < eval_tokens.size(); ++j) {
+                n_past++;
+                bool is_last_token = i + j == eval_tokens.size() - 1;
+                common_batch_add(batch, eval_tokens[i + j], n_past, {0}, is_last_token);
+            }
+
+            if (llama_decode(ctx, batch)) {
+                GGML_ABORT("failed to decode\n");
+            }
+        }
+
+        // update cache tokens
+        if (is_generating) {
+            cache_tokens.push_back(eval_tokens[0]);
+        } else {
+            cache_tokens.insert(cache_tokens.end(), eval_tokens.begin(), eval_tokens.end());
+        }
+    }
+
+    [[noreturn]] void run() {
+        while (true) {
+            interacting = true;
+            LOG("\n> ");
+
+            // color user input only
+            console::set_display(console::user_input);
+            std::string line;
+            bool another_line = true;
+            bool continue_input = false;
+            do {
+                another_line = console::readline(line, params.multiline_input);
+                if (handle_command(line, continue_input)) {
+                    continue; // do not add this line to pending_input
+                }
+                pending_input << line;
+            } while (another_line);
+
+            if (continue_input) {
+                continue;
+            }
+
+            if (pending_input.tellp() == 0) {
+                LOG_DBG("empty line, passing control back\n");
+                continue;
+            }
+
+            // done taking input, reset color
+            console::set_display(console::reset);
+            interacting = false;
+
+            // add message and format chat
+            if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
+                chat_msgs.pop_back();
+            }
+            chat_msgs.push_back({"user", string_strip(pending_input.str())});
+            pending_input.str(""); // clear
+            auto formatted = common_chat_apply_template(model, params.chat_template, chat_msgs, true);
+
+            // tokenize the new chat history and decode
+            llama_tokens prompt_tokens = common_tokenize(ctx, formatted, true, true);
+            decode(prompt_tokens, false);
+
+            // generate response
+            llama_token new_token_id = LLAMA_TOKEN_NULL;
+            llama_tokens generated_tokens;
+            common_sampler_reset(smpl);
+            while (true) {
+                if (interacting) {
+                    break;
+                }
+
+                // sample the next token
+                new_token_id = common_sampler_sample(smpl, ctx, -1);
+
+                // is it an end of generation?
+                if (llama_token_is_eog(model, new_token_id)) {
+                    break;
+                }
+
+                // print the token, then decode it
+                printf("%s", common_token_to_piece(ctx, new_token_id, params.special).c_str());
+                fflush(stdout);
+                generated_tokens.push_back(new_token_id);
+                llama_tokens new_tok = {new_token_id};
+                decode(new_tok, true);
+            }
+
+            // add the generated tokens to the chat history
+            std::string response = common_detokenize(ctx, generated_tokens, true);
+            chat_msgs.push_back({"assistant", response});
+
+            // print a new line if needed
+            if (!response.empty() && response.back() != '\n') {
+                printf("\n");
+            }
+        }
+    }
+
+    void interrupt() {
+        if (interacting) {
+            // exit
+            printf("\n");
+            console::cleanup();
+            common_perf_print(ctx, smpl);
+            common_log_pause(common_log_main());
+            exit(0);
+        }
+        interacting = true;
+    }
+
+    bool handle_command(std::string & inp, bool & continue_input) {
+        if (inp.empty() || inp[0] != '/') {
+            return false; // not a command
+        }
+        auto parts = string_split<std::string>(string_strip(inp), ' ');
+        std::string & cmd = parts[0];
+        if (cmd == "/help") {
+            LOG("TODO\n");
+            continue_input = true;
+        } else if (cmd == "/history") {
+            display_history();
+            continue_input = true;
+        } else if (cmd == "/regen") {
+            if (chat_msgs.empty()) {
+                LOG_ERR("no chat history to regenerate\n");
+                continue_input = true;
+                return true;
+            }
+            if (chat_msgs.back().role == "assistant") {
+                chat_msgs.pop_back();
+            }
+            if (chat_msgs.back().role == "user") {
+                pending_input.str(""); // clear
+                pending_input << chat_msgs.back().content;
+                chat_msgs.pop_back();
+            }
+            continue_input = false;
+        } else if (cmd == "/readfile") {
+            const std::string filename = parts[1];
+            LOG_DBG("reading file: '%s'\n", filename.c_str());
+            std::ifstream text_file(filename);
+            if (!text_file) {
+                LOG("failed to open file '%s'\n", filename.c_str());
+            } else {
+                pending_input << text_file.rdbuf() << "\n\n";
+                LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+            }
+            continue_input = true;
+        } else {
+            LOG_ERR("unknown command: %s\n", cmd.c_str());
+            continue_input = true;
+        }
+        return true;
+    }
+
+    void display_history() {
+        for (const auto & msg : chat_msgs) {
+            LOG("%s: %s\n\n", msg.role.c_str(), msg.content.c_str());
+        }
+    }
+
+    ~llama_cli_chat() {
+        llama_batch_free(batch);
+    }
+};

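The cache-reuse step in `decode()` above re-evaluates only the part of the freshly formatted chat history that diverges from what is already in the KV cache: it takes the longest common prefix of `cache_tokens` and the incoming tokens, trims that prefix from the work list, and drops the stale tail from the cache. Below is a minimal standalone sketch of that idea; the plain `int` token IDs and the `decode_tokens()` helper are illustrative stand-ins, not part of the llama.cpp API.

#include <algorithm>
#include <cstdio>
#include <vector>

// longest common prefix of two token sequences (same idea as common_lcp)
static size_t lcp(const std::vector<int> & a, const std::vector<int> & b) {
    size_t n = std::min(a.size(), b.size());
    size_t i = 0;
    while (i < n && a[i] == b[i]) {
        i++;
    }
    return i;
}

// hypothetical stand-in for the real decode path; it only reports what would be evaluated
static void decode_tokens(const std::vector<int> & toks, size_t n_past) {
    printf("decoding %zu token(s) starting at position %zu\n", toks.size(), n_past);
}

int main() {
    std::vector<int> cache  = {1, 10, 20, 30, 40};        // tokens already in the KV cache
    std::vector<int> prompt = {1, 10, 20, 99, 100, 101};  // newly formatted chat history

    size_t n_past = lcp(cache, prompt);
    if (n_past == prompt.size()) {
        n_past--; // re-generation of an unchanged history: keep at least one token to decode
    }

    // drop the shared prefix from the work list and truncate the cache to match
    std::vector<int> to_eval(prompt.begin() + n_past, prompt.end());
    cache.resize(n_past);
    cache.insert(cache.end(), to_eval.begin(), to_eval.end());

    decode_tokens(to_eval, n_past); // only the divergent suffix is evaluated
    return 0;
}

With the values shown, only the three divergent tokens (99, 100, 101) are decoded starting at position 3, which is the saving the chat loop gets each time the user appends a turn to an otherwise unchanged history.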
examples/main/main.cpp

Lines changed: 20 additions & 21 deletions

@@ -4,6 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
+#include "chat.hpp"

 #include <cassert>
 #include <cstdio>

@@ -35,6 +36,7 @@ static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
 static common_params * g_params;
+static llama_cli_chat * g_chat;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;

@@ -65,7 +67,9 @@ static bool file_is_empty(const std::string & path) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {
-        if (!is_interacting && g_params->interactive) {
+        if (g_chat) {
+            g_chat->interrupt();
+        } else if (!is_interacting && g_params->interactive) {
             is_interacting = true;
             need_insert_eot = true;
         } else {

@@ -83,14 +87,6 @@ static void sigint_handler(int signo) {
 }
 #endif

-static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
-    common_chat_msg new_msg{role, content};
-    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
-    chat_msgs.push_back({role, content});
-    LOG_DBG("formatted: '%s'\n", formatted.c_str());
-    return formatted;
-}
-
 int main(int argc, char ** argv) {
     common_params params;
     g_params = &params;

@@ -203,6 +199,12 @@ int main(int argc, char ** argv) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }

+    // switch on conversation mode if chat template is present
+    if (!params.chat_template.empty() || !common_get_builtin_chat_template(model).empty()) {
+        LOG("%s: using chat mode\n", __func__);
+        params.conversation = true;
+    }
+
     // print chat template example in conversation mode
     if (params.conversation) {
         if (params.enable_chat_template) {

@@ -251,18 +253,15 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;

     {
-        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
-            : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
-            embd_inp = common_tokenize(ctx, prompt, true, true);
+            embd_inp = common_tokenize(ctx, params.prompt, true, true);
         } else {
             LOG_DBG("use session tokens\n");
             embd_inp = session_tokens;
         }

-        LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
+        LOG_DBG("prompt: \"%s\"\n", params.prompt.c_str());
         LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
     }

@@ -420,6 +419,12 @@ int main(int argc, char ** argv) {

     LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

+    if (params.conversation) {
+        llama_cli_chat chat(params, ctx, smpl);
+        g_chat = &chat;
+        chat.run();
+    }
+
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
     int ga_i = 0;

@@ -752,10 +757,6 @@ int main(int argc, char ** argv) {
                        embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                        is_antiprompt = true;
                    }
-
-                   if (params.enable_chat_template) {
-                       chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
-                   }
                    is_interacting = true;
                    LOG("\n");
                }

@@ -818,9 +819,7 @@ int main(int argc, char ** argv) {
            }

            bool format_chat = params.conversation && params.enable_chat_template;
-           std::string user_inp = format_chat
-               ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
-               : std::move(buffer);
+           std::string user_inp = std::move(buffer);
            // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
            const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
            const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat);
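The SIGINT routing above hands Ctrl+C to `llama_cli_chat::interrupt()` whenever chat mode is active: a Ctrl+C during generation flips the `interacting` flag so the decode and sampling loops break out and control returns to the prompt, while a Ctrl+C at the prompt (when `interacting` is already set) exits. A reduced standalone sketch of that two-stage pattern follows; it uses a `std::atomic` flag and a dummy generation loop in place of the real console and logging helpers, so it is an illustration of the control flow, not the commit's code.

#include <atomic>
#include <chrono>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <thread>

// set by the first Ctrl+C: asks the generation loop to stop
static std::atomic<bool> interrupt_requested{false};

static void sigint_handler(int /*signo*/) {
    if (interrupt_requested.load()) {
        // Ctrl+C while already interrupted / at the prompt: leave immediately
        std::_Exit(0);
    }
    interrupt_requested.store(true);
}

// stand-in for the token generation loop: it checks the flag between tokens,
// the way decode() and the sampling loop check `interacting`
static void generate() {
    for (int i = 0; i < 50; ++i) {
        if (interrupt_requested.load()) {
            printf("\n[generation interrupted]\n");
            return;
        }
        printf(".");
        fflush(stdout);
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
}

int main() {
    std::signal(SIGINT, sigint_handler);

    interrupt_requested.store(false);
    generate();                        // first Ctrl+C here stops generation early

    interrupt_requested.store(true);   // like `interacting = true` at the prompt:
    printf("\n> (Ctrl+C now exits)\n"); // a Ctrl+C while waiting here terminates
    std::this_thread::sleep_for(std::chrono::seconds(5)); // stand-in for console::readline
    return 0;
}

The real implementation keeps the flag as a plain `bool` member and performs cleanup (console reset, perf print, log pause) before exiting; the sketch keeps only the control flow.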
