Commit 20270fd

Eric Curtin committed
Some llama-run cleanups
Use the consolidated open() call from the File class. Change read_all() to to_string(). Remove the exclusive lock: its intent was to stop multiple processes from writing to the same file at once, which is not an issue for readers, although we may want to consider adding a shared lock. Stop passing nullptr as a reference; references are never supposed to be null.

Signed-off-by: Eric Curtin <[email protected]>
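The shared lock mentioned above is floated as an idea, not implemented in this commit. A minimal sketch of what it could look like on a POSIX system using flock(), assuming the File class keeps its FILE * member (called file in the diff below); the helper names are hypothetical:

    #include <cstdio>      // FILE, fileno
    #include <sys/file.h>  // flock, LOCK_SH, LOCK_UN

    // Hypothetical helpers, not part of this commit. A shared lock lets any
    // number of readers hold the file at once, but blocks while a writer
    // holds an exclusive (LOCK_EX) lock on the same file.
    static bool lock_shared(FILE * file) {
        return flock(fileno(file), LOCK_SH) == 0;
    }

    static void unlock_shared(FILE * file) {
        flock(fileno(file), LOCK_UN);
    }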
1 parent 4806498 commit 20270fd

File tree

1 file changed

examples/run/run.cpp

Lines changed: 25 additions & 35 deletions
@@ -323,25 +323,19 @@ class File {
         return 0;
     }
 
-    std::string read_all(const std::string & filename){
-        open(filename, "r");
-        lock();
-        if (!file) {
-            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
-        }
-
+    std::string to_string() {
         fseek(file, 0, SEEK_END);
-        size_t size = ftell(file);
+        const size_t size = ftell(file);
         fseek(file, 0, SEEK_SET);
-
         std::string out;
-        out.resize(size);
-        size_t read_size = fread(&out[0], 1, size, file);
-        if (read_size != size) {
-            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
-            return "";
+        if (size > 0) {
+            out.resize(size);
+            const size_t read_size = fread(&out[0], 1, size, file);
+            if (read_size != size) {
+                printe("Error reading file: %s", strerror(errno));
+            }
         }
+
         return out;
     }
 
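With read_all() gone, opening and reading are separate steps: the caller opens the file and checks the result before asking for the contents. A minimal usage sketch of the new pattern (path stands in for any std::string filename), mirroring the read_chat_template_file() change in the next hunk:

    File file;
    if (!file.open(path, "r")) {
        // open() reports failure through its return value, so to_string()
        // is only ever called on a successfully opened file
        printe("Error opening file '%s': %s", path.c_str(), strerror(errno));
        return "";
    }

    return file.to_string();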
@@ -1098,42 +1092,37 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 
 // Reads a chat template file to be used
 static std::string read_chat_template_file(const std::string & chat_template_file) {
-    if(chat_template_file.empty()){
-        return "";
-    }
-
     File file;
-    std::string chat_template = "";
-    chat_template = file.read_all(chat_template_file);
-    if(chat_template.empty()){
+    if (!file.open(chat_template_file, "r")) {
         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
         return "";
     }
-    return chat_template;
+
+    return file.to_string();
 }
 
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const Opt & opt) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-
-    std::string chat_template = "";
-    if(!chat_template_file.empty()){
-        chat_template = read_chat_template_file(chat_template_file);
+    std::string chat_template;
+    if (!opt.chat_template_file.empty()) {
+        chat_template = read_chat_template_file(opt.chat_template_file);
     }
-    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
 
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
         std::string user_input;
-        if (get_user_input(user_input, user) == 1) {
+        if (get_user_input(user_input, opt.user) == 1) {
             return 0;
         }
 
-        add_message("user", user.empty() ? user_input : user, llama_data);
+        add_message("user", opt.user.empty() ? user_input : opt.user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) <
+            0) {
             return 1;
         }
 
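The dropped ternary is the "passing nullptr as reference" case from the commit message: assuming common_chat_templates_init() takes the template override by const std::string &, the expression chat_template.empty() ? nullptr : chat_template had to materialize a temporary std::string from nullptr whenever the string was empty, which is undefined behavior. A self-contained illustration (take_ref() is illustrative, not the real API):

    #include <string>

    static void take_ref(const std::string & s) { (void) s; }

    int main() {
        std::string chat_template;  // legitimately empty when no file is given

        // Old pattern: when chat_template is empty, the ternary constructs a
        // temporary std::string from nullptr -- undefined behavior:
        // take_ref(chat_template.empty() ? nullptr : chat_template);

        // New pattern: pass the string as-is; an empty value is valid.
        take_ref(chat_template);
        return 0;
    }
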
@@ -1143,12 +1132,13 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
             return 1;
         }
 
-        if (!user.empty()) {
+        if (!opt.user.empty()) {
             break;
         }
 
         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) <
+            0) {
             return 1;
         }
     }
@@ -1208,7 +1198,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt)) {
         return 1;
     }
 