@@ -323,25 +323,17 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
-
         std::string out;
         out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
+        const size_t read_size = fread(&out[0], 1, size, file);
         if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+            printe("Error reading file: %s", strerror(errno));
         }
+
        return out;
     }
 
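This hunk narrows the reader: read_all(filename) used to open, lock, and validate the handle itself, while the new to_string() assumes the file is already open and only slurps it, leaving open() error handling to the caller (see read_chat_template_file below). A standalone sketch of the same slurp pattern, using plain stdio instead of the File wrapper so it compiles on its own; the filename is hypothetical:

    #include <cerrno>
    #include <cstdio>
    #include <cstring>
    #include <string>

    // Same pattern as File::to_string(): seek to the end to learn the size,
    // rewind, then read everything into a std::string with one fread().
    static std::string slurp(std::FILE * file) {
        std::fseek(file, 0, SEEK_END);
        const long size = std::ftell(file);  // the patch stores this in a size_t; ftell() itself returns long
        std::fseek(file, 0, SEEK_SET);
        std::string out;
        if (size <= 0) {
            return out;  // empty file or ftell() failure
        }
        out.resize(static_cast<size_t>(size));
        const size_t read_size = std::fread(&out[0], 1, out.size(), file);
        if (read_size != out.size()) {
            std::fprintf(stderr, "Error reading file: %s\n", std::strerror(errno));
        }
        return out;
    }

    int main() {
        std::FILE * file = std::fopen("prompt.txt", "r");  // hypothetical input file
        if (!file) {
            std::fprintf(stderr, "Error opening file: %s\n", std::strerror(errno));
            return 1;
        }
        const std::string contents = slurp(file);
        std::fclose(file);
        std::printf("read %zu bytes\n", contents.size());
        return 0;
    }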
@@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if (chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if (chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
+}
+
+static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data,
+                                const common_chat_templates_ptr & chat_templates, int & prev_len,
+                                const bool stdout_a_terminal) {
+    add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
+    int new_len;
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
+    std::string response;
+    if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        return 1;
+    }
+
+    if (!opt.user.empty()) {
+        return 2;
+    }
+
+    add_message("assistant", response, llama_data);
+    if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) {
+        return 1;
+    }
+
+    return 0;
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
-    std::string chat_template = "";
-    if (!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
 
+    common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
             return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
-        int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
-            return 1;
-        }
-
-        std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
-        std::string response;
-        if (generate_response(llama_data, prompt, response, stdout_a_terminal)) {
+        const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal);
+        if (ret == 1) {
             return 1;
-        }
-
-        if (!user.empty()) {
+        } else if (ret == 2) {
             break;
         }
-
-        add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
-            return 1;
-        }
     }
 
     return 0;
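process_user_message() is the old loop body extracted whole, and it reports a tri-state int rather than a bool: 0 means the turn finished and the loop should continue, 1 is an error propagated up to main(), and 2 means a one-shot prompt (opt.user non-empty) was answered and the loop can end cleanly. A self-contained sketch of that dispatch, with a hypothetical stand-in for the real function:

    #include <cstdio>

    // Hypothetical stand-in for process_user_message(); the real one formats the
    // prompt, generates the response, and appends both to the chat history.
    static int process_turn(bool single_shot) {
        // a template or generation failure would return 1 here
        return single_shot ? 2 : 0;
    }

    int main() {
        const bool single_shot = true;  // corresponds to opt.user being non-empty
        while (true) {
            const int ret = process_turn(single_shot);
            if (ret == 1) {
                return 1;   // error: bubble the failure up, as chat_loop() does
            } else if (ret == 2) {
                break;      // one-shot prompt answered: leave the loop cleanly
            }
        }
        std::puts("chat finished");
        return 0;
    }

Note also that common_chat_templates_init() is now handed the chat_template string directly, so the old nullptr-when-empty dance at the call site disappears.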
@@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }
 
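The one-line change in main() is the payoff of the parameter-object refactor: chat_loop() takes the whole Opt struct, so a new flag means one new field rather than another argument threaded through every signature. A minimal sketch of the pattern, with a hypothetical subset of Opt since its full definition is outside this diff:

    #include <cstdio>
    #include <string>

    // Hypothetical subset of Opt; the real struct in run.cpp holds the parsed CLI flags.
    struct Opt {
        std::string user;
        std::string chat_template_file;
        bool        use_jinja = false;
    };

    // One parameter object instead of three loose arguments: adding a flag to Opt
    // changes neither this signature nor any caller.
    static int chat_loop(const Opt & opt) {
        std::printf("use_jinja: %d, template file: '%s'\n",
                    opt.use_jinja ? 1 : 0, opt.chat_template_file.c_str());
        return 0;
    }

    int main() {
        Opt opt;
        opt.use_jinja = true;
        return chat_loop(opt);
    }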