
Commit 1716e6b

add some other commands
1 parent d7a4f3e commit 1716e6b

3 files changed: +113 -21 lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions

@@ -1939,6 +1939,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.simple_io = true;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
+    add_opt(common_arg(
+        {"-nsc", "--no-special-command"},
+        string_format("disable special commands in conversation mode (default: %s)", params.special_cmds ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.special_cmds = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-ld", "--logdir"}, "LOGDIR",
         "path under which to save YAML logs (no logging if unset)",

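For context: each add_opt(common_arg(...)) call above registers the flag's names, a help string, and a callback that mutates common_params when the flag is seen on the command line. A minimal standalone sketch of that handler pattern (the Params and Arg types here are illustrative stand-ins, not the real llama.cpp definitions):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Params { bool special_cmds = true; };   // stand-in for common_params

    struct Arg {
        std::vector<std::string>      names;       // e.g. {"-nsc", "--no-special-command"}
        std::string                   help;
        std::function<void(Params &)> handler;     // runs when the flag is seen
    };

    int main() {
        Params params;
        std::vector<Arg> opts;
        opts.push_back(Arg{{"-nsc", "--no-special-command"},
                           "disable special commands in conversation mode",
                           [](Params & p) { p.special_cmds = false; }});

        const std::string seen = "-nsc";           // pretend this came from argv
        for (const auto & opt : opts) {
            for (const auto & name : opt.names) {
                if (name == seen) {
                    opt.handler(params);
                }
            }
        }
        std::cout << "special_cmds = " << std::boolalpha << params.special_cmds << "\n"; // prints false
    }

With the real flag, running the main example with something like "llama-cli -m model.gguf -cnv -nsc" (the exact binary name depends on the build) would leave /-prefixed input untouched and send it to the model verbatim.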
common/common.h

Lines changed: 1 addition & 0 deletions

@@ -251,6 +251,7 @@ struct common_params {
     bool conversation     = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool prompt_cache_all = false; // save user input and generations to prompt cache
     bool prompt_cache_ro  = false; // open the prompt cache read-only and do not update it
+    bool special_cmds     = true;  // enable special commands in main example

     bool escape           = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
     bool multiline_input  = false; // reverse the usage of `\`

examples/main/main.cpp

Lines changed: 105 additions & 21 deletions
@@ -31,10 +31,6 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif

-static const std::string CMD_READFILE  = "/readfile";
-static const std::string CMD_SAVE_SESS = "/savesess";
-static const std::string CMD_LOAD_SESS = "/loadsess";
-
 static llama_context  ** g_ctx;
 static llama_model    ** g_model;
 static common_sampler ** g_smpl;
@@ -45,13 +41,22 @@ static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting  = false;
 static bool need_insert_eot = false;

+static const char * help_special_cmds = "special commands in conversation mode:\n"
+    "  /readfile FILE  read prompt from file\n"
+    "  /savesess FILE  save session to file\n"
+    "  /loadsess FILE  load session from file\n"
+    "  /regen          regenerate the last response\n"
+    "  /dump FILE      dump chat content to a file\n";
+
 static void print_usage(int argc, char ** argv) {
     (void) argc;

     LOG("\nexample usage:\n");
     LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
     LOG("\n  chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
     LOG("\n");
+    LOG("%s", help_special_cmds);
+    LOG("\n");
 }

 static bool file_exists(const std::string & path) {
@@ -109,6 +114,21 @@ static void write_logfile(
     fclose(logfile);
 }

+static std::vector<std::string> try_parse_command(std::string text) {
+    if (text.empty() || text[0] != '/') {
+        return {};
+    }
+    std::vector<std::string> elem = string_split<std::string>(text, ' ');
+    std::vector<std::string> res;
+    // filter empty strings
+    for (const auto & e : elem) {
+        if (!e.empty()) {
+            res.push_back(string_strip(e));
+        }
+    }
+    return res;
+}
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {
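To make the parsing concrete: try_parse_command returns an empty vector unless the line starts with '/', and otherwise returns the whitespace-separated tokens. A standalone sketch of the same behavior using only the standard library (parse_cmd stands in for try_parse_command; llama.cpp's string_split/string_strip are assumed to split on the delimiter and trim whitespace):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // returns {} unless the line starts with '/'; otherwise the non-empty tokens
    static std::vector<std::string> parse_cmd(const std::string & text) {
        if (text.empty() || text[0] != '/') {
            return {};
        }
        std::vector<std::string> res;
        std::istringstream iss(text);
        std::string tok;
        while (iss >> tok) { // operator>> skips runs of whitespace
            res.push_back(tok);
        }
        return res;
    }

    int main() {
        for (const std::string s : {"/readfile  notes.txt", "/regen", "hello /world"}) {
            std::cout << "'" << s << "' -> " << parse_cmd(s).size() << " token(s)\n";
        }
        // '/readfile  notes.txt' -> 2 token(s)  (dispatched as a command)
        // '/regen'               -> 1 token(s)  (dispatched as a command)
        // 'hello /world'         -> 0 token(s)  (sent to the model as-is)
    }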
@@ -131,7 +151,11 @@ static void sigint_handler(int signo) {
 }
 #endif

+// return the formatted turn to be decoded
 static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
+    if (content.empty()) {
+        return "";
+    }
     common_chat_msg new_msg{role, content};
     auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
@@ -193,6 +217,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = nullptr;
     common_sampler * smpl = nullptr;

+    std::vector<int> pos_history; // history of positions of chat messages
     std::vector<common_chat_msg> chat_msgs;

     g_model = &model;
@@ -519,6 +544,7 @@ int main(int argc, char ** argv) {
     display = params.display_prompt;

     std::vector<llama_token> embd;
+    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);

     // tokenized antiprompts
     std::vector<std::vector<llama_token>> antiprompt_ids;
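The batch is allocated once with capacity params.n_batch tokens, no embedding inputs, and a single sequence, then released via llama_batch_free just before llama_free at the end of main (see the last hunk). If one wanted the init/free pairing enforced automatically, a small RAII guard over the C API would do; a hedged sketch (batch_guard is not part of llama.cpp, purely illustrative):

    #include "llama.h"

    // illustrative RAII wrapper: frees the batch when the guard leaves scope
    struct batch_guard {
        llama_batch batch;

        batch_guard(int32_t n_tokens, int32_t embd, int32_t n_seq_max)
            : batch(llama_batch_init(n_tokens, embd, n_seq_max)) {}
        ~batch_guard() { llama_batch_free(batch); }

        // non-copyable: the guard uniquely owns the allocation
        batch_guard(const batch_guard &) = delete;
        batch_guard & operator=(const batch_guard &) = delete;
    };

    // usage: batch_guard guard(params.n_batch, 0, 1); then pass guard.batch around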
@@ -546,6 +572,8 @@ int main(int argc, char ** argv) {
         embd_inp.push_back(decoder_start_token_id);
     }

+    std::stringstream pending_input; // used by "/readfile" command
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
@@ -652,7 +680,19 @@ int main(int argc, char ** argv) {

                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+                common_batch_clear(batch);
+                for (int j = 0; j < n_eval; j++) {
+                    int idx = i + j;
+                    common_batch_add(
+                        batch,
+                        embd[idx],
+                        n_past + idx,
+                        {0},
+                        idx == (int) embd.size() - 1
+                    );
+                }
+
+                if (llama_decode(ctx, batch)) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;
                 }
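The replacement fills the batch explicitly instead of relying on llama_batch_get_one: every token carries an explicit position and sequence id, and only the last token of the whole input requests logits, since that is the position whose next-token distribution the sampler reads. A standalone simulation of the fill loop with plain structs and made-up numbers (chunk size 4, six tokens, 10 tokens already in the cache):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    struct entry { int token; int pos; bool logits; };

    int main() {
        const std::vector<int> embd = {11, 22, 33, 44, 55, 66}; // pretend token ids
        const int n_batch = 4;                                  // chunk size
        const int n_past  = 10;                                 // tokens already in the KV cache

        for (int i = 0; i < (int) embd.size(); i += n_batch) {
            const int n_eval = std::min((int) embd.size() - i, n_batch);
            std::vector<entry> batch; // stands in for common_batch_clear + common_batch_add
            for (int j = 0; j < n_eval; j++) {
                const int idx = i + j;
                batch.push_back({embd[idx], n_past + idx, idx == (int) embd.size() - 1});
            }
            for (const auto & e : batch) {
                std::printf("chunk@%d: token=%d pos=%d logits=%d\n", i, e.token, e.pos, (int) e.logits);
            }
        }
        // only the final token (66, pos 15) has logits=1
    }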
@@ -856,42 +896,83 @@ int main(int argc, char ** argv) {
             LOG_DBG("buffer: '%s'\n", buffer.c_str());

             // check for special commands
-            if (buffer.rfind(CMD_READFILE, 0) == 0) {
-                const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
+            const std::vector<std::string> cmd = params.special_cmds
+                ? try_parse_command(buffer)
+                : std::vector<std::string>();
+
+            if (cmd.size() == 2 && cmd[0] == "/readfile") {
+                const std::string filename = cmd[1];
                 LOG_DBG("reading file: '%s'\n", filename.c_str());
                 std::ifstream text_file(filename);
                 if (!text_file) {
                     LOG("failed to open file '%s'\n", filename.c_str());
                     continue;
                 }
-                std::stringstream tmp;
-                tmp << text_file.rdbuf();
-                buffer = tmp.str();
-                LOG("%s\n", buffer.c_str());
-            } else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
-                const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
+                pending_input << text_file.rdbuf() << "\n\n";
+                LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+                continue;
+            } else if (cmd.size() == 2 && cmd[0] == "/savesess") {
+                const std::string filename = cmd[1];
                 LOG("save session file: '%s'\n", filename.c_str());
                 size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
                 if (res == 0) {
                     LOG("failed to save session file '%s'\n", filename.c_str());
                 }
                 continue;
-            } else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
-                const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
+            } else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
+                const std::string filename = cmd[1];
                 LOG("load session file: '%s'\n", filename.c_str());
-                std::vector<llama_token> sess_tokens;
-                sess_tokens.resize(n_ctx);
-                size_t n_loaded_tokens;
-                size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
+                session_tokens.resize(n_ctx);
+                size_t n_token_count_out;
+                size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
                 if (res == 0) {
                     LOG("failed to load session file '%s'\n", filename.c_str());
                 } else {
-                    n_past = n_loaded_tokens;
-                    LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
+                    session_tokens.resize(n_token_count_out);
+                    embd_inp = session_tokens;
+                    n_past = n_token_count_out;
+                    llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+                    LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
+                }
+                continue;
+            } else if (cmd.size() == 1 && cmd[0] == "/regen") {
+                if (pos_history.empty()) {
+                    LOG("no previous assistant message to regenerate\n");
+                    continue;
+                }
+                int last_n_past = pos_history.back();
+                int n_tokens_removed = n_past - last_n_past;
+                llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
+                n_remain += n_tokens_removed;
+                is_interacting = false;
+                // we intentionally do not reset the sampling, so the new message will be more diverse
+                continue;
+            } else if (cmd.size() == 2 && cmd[0] == "/dump") {
+                const std::string filename = cmd[1];
+                std::ofstream dump_file(filename);
+                if (!dump_file) {
+                    LOG("failed to create file '%s'\n", filename.c_str());
+                    continue;
+                }
+                for (const auto & msg : chat_msgs) {
+                    dump_file << msg.role << ":\n" << msg.content << "\n---\n";
                 }
+                dump_file.close();
+                LOG("dumped chat messages to file '%s'\n", filename.c_str());
+                continue;
+            } else if (!cmd.empty()) {
+                LOG("unknown command: %s\n", buffer.c_str());
+                LOG("%s", help_special_cmds);
                 continue;
             }

+            if (pending_input.tellp() > 0) {
+                // concatenate the file contents read earlier and the new prompt
+                pending_input << buffer;
+                buffer = pending_input.str();
+                pending_input.str(""); // str("") empties the buffer; clear() would only reset error flags
+            }
+
             const size_t original_size = embd_inp.size();

             if (params.escape) {
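One stringstream subtlety behind the pending_input reset above: std::stringstream::clear() only resets the stream's error flags, while str("") is what actually discards the accumulated contents; without the str("") reset, tellp() would still be positive on the next user turn and the old file contents would be prepended again. A quick standalone check:

    #include <iostream>
    #include <sstream>

    int main() {
        std::stringstream ss;
        ss << "file contents\n\n" << "user prompt";

        ss.clear(); // resets eofbit/failbit/badbit only; contents remain
        std::cout << "after clear(): tellp = " << ss.tellp() << "\n"; // still > 0

        ss.str(""); // actually empties the buffer
        std::cout << "after str(\"\"): tellp = " << ss.tellp() << "\n"; // 0
    }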
@@ -926,6 +1007,8 @@ int main(int argc, char ** argv) {
                     output_ss << common_token_to_piece(ctx, token);
                 }

+                pos_history.push_back(n_past + embd_inp.size() - original_size);
+
                 // reset assistant message
                 assistant_ss.str("");

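pos_history records, once per user turn, the position the KV cache will have reached after the freshly appended input (embd_inp.size() - original_size tokens) is decoded, i.e. the position where the assistant's reply starts; /regen rolls the cache back to the most recent of these marks. A toy trace with made-up numbers:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> pos_history;
        int n_past = 100;                          // tokens decoded so far

        // user turn: 20 new prompt tokens are appended to embd_inp
        const int new_input = 20;
        pos_history.push_back(n_past + new_input); // 120: where the reply will start

        // the model then generates 35 reply tokens
        n_past += new_input + 35;                  // 155

        // "/regen": discard the reply and refund its token budget
        const int last_n_past      = pos_history.back();    // 120
        const int n_tokens_removed = n_past - last_n_past;  // 35
        std::printf("remove KV entries [%d, end), give %d tokens back to n_remain\n",
                    last_n_past, n_tokens_removed);
    }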
@@ -971,6 +1054,7 @@ int main(int argc, char ** argv) {

     common_sampler_free(smpl);

+    llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
