#include "arg.h"
#include "common.h"
#include "console.h"
#include "log.h"
#include "sampling.h"
#include "llama.h"

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

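// state of an interactive chat session: wraps the llama_context, the sampler and the
// chat history, and keeps a copy of the tokens currently in the KV cache (cache_tokens)
// so that on each turn only the suffix that differs from the previous prompt needs to
// be decoded again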
struct llama_cli_chat {
    struct llama_context * ctx;
    const struct llama_model * model;
    struct common_sampler * smpl;
    struct common_params params;

    bool interacting = false;
    std::vector<common_chat_msg> chat_msgs;
    std::ostringstream pending_input;

    struct llama_batch batch;
    llama_tokens cache_tokens;
    int n_past = 0;

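    // the batch is allocated once and reused: up to n_batch tokens, no embeddings, a single sequence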
    llama_cli_chat(
            struct common_params & params,
            struct llama_context * ctx,
            struct common_sampler * smpl) : ctx(ctx), smpl(smpl), params(params) {
        model = llama_get_model(ctx);
        batch = llama_batch_init(params.n_batch, 0, 1);
    }

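    // evaluate tokens and advance the KV cache
    // - prompt processing (is_generating == false): reuse the longest common prefix
    //   between cache_tokens and eval_tokens, drop the stale tail of the KV cache and
    //   decode only the new suffix, in chunks of n_batch tokens
    // - generation (is_generating == true): eval_tokens holds exactly one new token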
    void decode(llama_tokens & eval_tokens, bool is_generating) {
        if (is_generating) {
            GGML_ASSERT(eval_tokens.size() == 1);
        } else {
            n_past = common_lcp(cache_tokens, eval_tokens);
            // in case we do a re-generation, we need to prevent eval_tokens from being empty
            if ((int) eval_tokens.size() == n_past) {
                n_past--;
            }
            if (n_past > 0) {
                eval_tokens.erase(eval_tokens.begin(), eval_tokens.begin() + n_past);
                cache_tokens.erase(cache_tokens.begin() + n_past, cache_tokens.end());
                LOG_DBG("remove from cache [%d, inf)\n", n_past);
                LOG_DBG("in cache: %s\n", common_detokenize(ctx, cache_tokens, true).c_str());
                LOG_DBG("to decode %d tokens\n", (int) eval_tokens.size());
                llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
            }
        }

        // decode
        for (size_t i = 0; i < eval_tokens.size(); i += params.n_batch) {
            if (interacting) {
                break;
            }

            common_batch_clear(batch);
            for (int j = 0; j < params.n_batch && i + j < eval_tokens.size(); ++j) {
                // positions are zero-based: the token goes to position n_past, which keeps
                // the KV cache consistent with common_lcp() and llama_kv_cache_seq_rm() above
                bool is_last_token = i + j == eval_tokens.size() - 1;
                common_batch_add(batch, eval_tokens[i + j], n_past, {0}, is_last_token);
                n_past++;
            }

            if (llama_decode(ctx, batch)) {
                GGML_ABORT("failed to decode\n");
            }
        }

        // update cache tokens
        if (is_generating) {
            cache_tokens.push_back(eval_tokens[0]);
        } else {
            cache_tokens.insert(cache_tokens.end(), eval_tokens.begin(), eval_tokens.end());
        }
    }

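    // main interactive loop: read user input (lines starting with '/' are handled as
    // commands), append it to the chat history, apply the chat template, decode the
    // new part of the prompt, then sample tokens until end-of-generation or until the
    // user interrupts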
    [[noreturn]] void run() {
        while (true) {
            interacting = true;
            LOG("\n> ");

            // color user input only
            console::set_display(console::user_input);
            std::string line;
            bool another_line = true;
            bool continue_input = false;
            do {
                another_line = console::readline(line, params.multiline_input);
                if (handle_command(line, continue_input)) {
                    continue; // do not add this line to pending_input
                }
                pending_input << line;
            } while (another_line);

            if (continue_input) {
                continue;
            }

            if (pending_input.tellp() == 0) {
                LOG_DBG("empty line, passing control back\n");
                continue;
            }

            // done taking input, reset color
            console::set_display(console::reset);
            interacting = false;

            // add message and format chat
            if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
                chat_msgs.pop_back();
            }
            chat_msgs.push_back({"user", string_strip(pending_input.str())});
            pending_input.str(""); // clear
            auto formatted = common_chat_apply_template(model, params.chat_template, chat_msgs, true);

            // tokenize the new chat history and decode
            llama_tokens prompt_tokens = common_tokenize(ctx, formatted, true, true);
            decode(prompt_tokens, false);

            // generate response
            llama_token new_token_id = LLAMA_TOKEN_NULL;
            llama_tokens generated_tokens;
            common_sampler_reset(smpl);
            while (true) {
                if (interacting) {
                    break;
                }

                // sample the next token and update the sampler state (penalties, grammar, etc.)
                new_token_id = common_sampler_sample(smpl, ctx, -1);
                common_sampler_accept(smpl, new_token_id, /* accept_grammar= */ true);

                // is it an end of generation?
                if (llama_token_is_eog(model, new_token_id)) {
                    break;
                }

                // print the token, then decode it
                printf("%s", common_token_to_piece(ctx, new_token_id, params.special).c_str());
                fflush(stdout);
                generated_tokens.push_back(new_token_id);
                llama_tokens new_tok = {new_token_id};
                decode(new_tok, true);
            }

            // add the generated tokens to the chat history
            std::string response = common_detokenize(ctx, generated_tokens, true);
            chat_msgs.push_back({"assistant", response});

            // print a new line if needed
            if (!response.empty() && response.back() != '\n') {
                printf("\n");
            }
        }
    }

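    // intended to be called on Ctrl-C (presumably from the caller's SIGINT handler):
    // while generating, a first call only requests a stop by setting `interacting`;
    // if we are already waiting for input, it cleans up and exits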
    void interrupt() {
        if (interacting) {
            // exit
            printf("\n");
            console::cleanup();
            common_perf_print(ctx, smpl);
            common_log_pause(common_log_main());
            exit(0);
        }
        interacting = true;
    }

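    // handle lines starting with '/' (e.g. /help, /history, /regen, /readfile FILE);
    // returns true if the line was consumed as a command; `continue_input` tells the
    // caller whether to go back to the prompt instead of submitting the pending input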
    bool handle_command(std::string & inp, bool & continue_input) {
        if (inp.empty() || inp[0] != '/') {
            return false; // not a command
        }
        auto parts = string_split<std::string>(string_strip(inp), ' ');
        std::string & cmd = parts[0];
        if (cmd == "/help") {
            LOG("TODO\n");
            continue_input = true;
        } else if (cmd == "/history") {
            display_history();
            continue_input = true;
        } else if (cmd == "/regen") {
            if (chat_msgs.empty()) {
                LOG_ERR("no chat history to regenerate\n");
                continue_input = true;
                return true;
            }
            if (chat_msgs.back().role == "assistant") {
                chat_msgs.pop_back();
            }
            if (!chat_msgs.empty() && chat_msgs.back().role == "user") {
                pending_input.str(""); // clear
                pending_input << chat_msgs.back().content;
                chat_msgs.pop_back();
            }
            continue_input = false;
        } else if (cmd == "/readfile") {
            if (parts.size() < 2) {
                LOG_ERR("usage: /readfile FILENAME\n");
            } else {
                const std::string filename = parts[1];
                LOG_DBG("reading file: '%s'\n", filename.c_str());
                std::ifstream text_file(filename);
                if (!text_file) {
                    LOG_ERR("failed to open file '%s'\n", filename.c_str());
                } else {
                    pending_input << text_file.rdbuf() << "\n\n";
                    LOG("read %zu characters from file\n", (size_t) text_file.tellg());
                }
            }
            continue_input = true;
        } else {
            LOG_ERR("unknown command: %s\n", cmd.c_str());
            continue_input = true;
        }
        return true;
    }

    void display_history() {
        for (const auto & msg : chat_msgs) {
            LOG("%s: %s\n\n", msg.role.c_str(), msg.content.c_str());
        }
    }

    ~llama_cli_chat() {
        llama_batch_free(batch);
    }
};
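
// Illustrative usage sketch (an assumption, not part of the code above): the driver
// presumably constructs the chat object after the usual model/context/sampler setup,
// installs a SIGINT handler that forwards Ctrl-C to interrupt(), and then calls run(),
// which never returns. Names such as g_chat and sigint_handler are hypothetical.
//
//     #include <csignal>
//
//     static llama_cli_chat * g_chat = nullptr;      // hypothetical global
//
//     static void sigint_handler(int) {
//         if (g_chat) {
//             g_chat->interrupt();                   // 1st Ctrl-C: stop generation
//         }                                          // 2nd Ctrl-C: clean exit
//     }
//
//     // in main(), once params, ctx and smpl have been created:
//     llama_cli_chat chat(params, ctx, smpl);
//     g_chat = &chat;
//     signal(SIGINT, sigint_handler);
//     chat.run();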