 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static const std::string CMD_READFILE  = "/readfile";
-static const std::string CMD_SAVE_SESS = "/savesess";
-static const std::string CMD_LOAD_SESS = "/loadsess";
-
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
@@ -45,13 +41,22 @@ static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 static bool need_insert_eot = false;
 
+static const char * help_special_cmds = "special commands in conversation mode:\n"
+    "  /readfile FILE    read prompt from file\n"
+    "  /savesess FILE    save session to file\n"
+    "  /loadsess FILE    load session from file\n"
+    "  /regen            regenerate the last response\n"
+    "  /dump FILE        dump chat content to a file\n";
+
 static void print_usage(int argc, char ** argv) {
     (void) argc;
 
     LOG("\nexample usage:\n");
     LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
     LOG("\n  chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
     LOG("\n");
+    LOG("%s", help_special_cmds);
+    LOG("\n");
 }
 
 static bool file_exists(const std::string & path) {
@@ -109,6 +114,21 @@ static void write_logfile(
     fclose(logfile);
 }
 
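+// split a "/command arg" line on whitespace; returns an empty vector if the line is not a command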
+static std::vector<std::string> try_parse_command(std::string text) {
+    if (text.empty() || text[0] != '/') {
+        return {};
+    }
+    std::vector<std::string> elem = string_split<std::string>(text, ' ');
+    std::vector<std::string> res;
+    // filter empty strings
+    for (const auto & e : elem) {
+        if (!e.empty()) {
+            res.push_back(string_strip(e));
+        }
+    }
+    return res;
+}
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {
@@ -131,7 +151,11 @@ static void sigint_handler(int signo) {
 }
 #endif
 
+// return the formatted turn to be decoded
 static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
+    if (content.empty()) {
+        return "";
+    }
     common_chat_msg new_msg{role, content};
     auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
@@ -193,6 +217,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = nullptr;
     common_sampler * smpl = nullptr;
 
+    std::vector<int> pos_history; // history of chat message positions (used by the /regen command)
     std::vector<common_chat_msg> chat_msgs;
 
     g_model = &model;
@@ -519,6 +544,7 @@ int main(int argc, char ** argv) {
     display = params.display_prompt;
 
     std::vector<llama_token> embd;
+    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
 
     // tokenized antiprompts
     std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -546,6 +572,8 @@ int main(int argc, char ** argv) {
         embd_inp.push_back(decoder_start_token_id);
     }
 
+    std::stringstream pending_input; // used by the "/readfile" command
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
@@ -652,7 +680,19 @@ int main(int argc, char ** argv) {
 
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+                common_batch_clear(batch);
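+                // add the tokens of this chunk at their absolute positions; only the last token of embd requests logits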
+                for (int j = 0; j < n_eval; j++) {
+                    int idx = i + j;
+                    common_batch_add(
+                        batch,
+                        embd[idx],
+                        n_past + idx,
+                        { 0 },
+                        idx == (int) embd.size() - 1
+                    );
+                }
+
+                if (llama_decode(ctx, batch)) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;
                 }
@@ -856,42 +896,83 @@ int main(int argc, char ** argv) {
             LOG_DBG("buffer: '%s'\n", buffer.c_str());
 
             // check for special commands
-            if (buffer.rfind(CMD_READFILE, 0) == 0) {
-                const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
+            const std::vector<std::string> cmd = params.special_cmds
+                ? try_parse_command(buffer)
+                : std::vector<std::string>();
+
+            if (cmd.size() == 2 && cmd[0] == "/readfile") {
+                const std::string filename = cmd[1];
                 LOG_DBG("reading file: '%s'\n", filename.c_str());
                 std::ifstream text_file(filename);
                 if (!text_file) {
                     LOG("failed to open file '%s'\n", filename.c_str());
                     continue;
                 }
-                std::stringstream tmp;
-                tmp << text_file.rdbuf();
-                buffer = tmp.str();
-                LOG("%s\n", buffer.c_str());
-            } else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
-                const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
+                pending_input << text_file.rdbuf() << "\n\n";
+                LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+                continue;
+            } else if (cmd.size() == 2 && cmd[0] == "/savesess") {
+                const std::string filename = cmd[1];
                 LOG("save session file: '%s'\n", filename.c_str());
                 size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
                 if (res == 0) {
                     LOG("failed to save session file '%s'\n", filename.c_str());
                 }
                 continue;
-            } else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
-                const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
+            } else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
+                const std::string filename = cmd[1];
                 LOG("load session file: '%s'\n", filename.c_str());
-                std::vector<llama_token> sess_tokens;
-                sess_tokens.resize(n_ctx);
-                size_t n_loaded_tokens;
-                size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
+                session_tokens.resize(n_ctx);
+                size_t n_token_count_out;
+                size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
                 if (res == 0) {
                     LOG("failed to load session file '%s'\n", filename.c_str());
                 } else {
-                    n_past = n_loaded_tokens;
-                    LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
+                    session_tokens.resize(n_token_count_out);
+                    embd_inp = session_tokens;
+                    n_past = n_token_count_out;
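+                    // drop any KV cache entries beyond the restored position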
+                    llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+                    LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
+                }
+                continue;
+            } else if (cmd.size() == 1 && cmd[0] == "/regen") {
+                if (pos_history.empty()) {
+                    LOG("no previous assistant message to regenerate\n");
+                    continue;
+                }
+                int last_n_past = pos_history.back();
+                int n_tokens_removed = n_past - last_n_past;
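+                // roll the KV cache back to the end of the last user turn and return the removed tokens to the budget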
+                llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
+                n_remain += n_tokens_removed;
+                is_interacting = false;
+                // we intentionally do not reset the sampling, so the new message will be more diverse
+                continue;
+            } else if (cmd.size() == 2 && cmd[0] == "/dump") {
+                const std::string filename = cmd[1];
+                std::ofstream dump_file(filename);
+                if (!dump_file) {
+                    LOG("failed to create file '%s'\n", filename.c_str());
+                    continue;
+                }
+                for (const auto & msg : chat_msgs) {
+                    dump_file << msg.role << ":\n" << msg.content << "\n---\n";
                 }
+                dump_file.close();
+                LOG("dumped chat messages to file '%s'\n", filename.c_str());
+                continue;
+            } else if (!cmd.empty()) {
+                LOG("unknown command: %s\n", buffer.c_str());
+                LOG("%s", help_special_cmds);
                 continue;
             }
 
+            if (pending_input.tellp() > 0) {
+                // concatenate the file contents read by /readfile with the prompt
+                pending_input << buffer;
+                buffer = pending_input.str();
+                pending_input.str(""); // reset the contents (clear() alone would only reset the error flags)
+            }
+
             const size_t original_size = embd_inp.size();
 
             if (params.escape) {
@@ -926,6 +1007,8 @@ int main(int argc, char ** argv) {
                     output_ss << common_token_to_piece(ctx, token);
                 }
 
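+                // remember the position reached once this user input has been decoded, so /regen can roll back to it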
+                pos_history.push_back(n_past + embd_inp.size() - original_size);
+
                 // reset assistant message
                 assistant_ss.str("");
 
@@ -971,6 +1054,7 @@ int main(int argc, char ** argv) {
 
     common_sampler_free(smpl);
 
+    llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
 