@@ -113,6 +113,7 @@ class Opt {
113 113     llama_context_params ctx_params;
114 114     llama_model_params model_params;
115 115     std::string model_;
116+        std::string chat_template_file;
116 117     std::string user;
117 118     bool use_jinja = false;
118 119     int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
148 149         return 0;
149 150     }
150 151
152+        int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
153+            if (i + 1 >= argc) {
154+                return 1;
155+            }
156+
157+            option_value = argv[++i];
158+
159+            return 0;
160+        }
161+
151 162     int parse(int argc, const char ** argv) {
152 163         bool options_parsing = true;
153 164         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
169 180             verbose = true;
170 181         } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
171 182             use_jinja = true;
183+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
184+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
185+                    return 1;
186+                }
187+                use_jinja = true;
172 188         } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
173 189             help = true;
174 190             return 0;
@@ -207,6 +223,11 @@ class Opt {
207 223         "Options:\n"
208 224         "  -c, --context-size <value>\n"
209 225         "      Context size (default: %d)\n"
226+            "  --chat-template-file <path>\n"
227+            "      Path to the file containing the chat template to use with the model.\n"
228+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
229+            "  --jinja\n"
230+            "      Use jinja templating for the chat template of the model\n"
210 231         "  -n, -ngl, --ngl <value>\n"
211 232         "      Number of GPU layers (default: %d)\n"
212 233         "  --temp <value>\n"
@@ -261,13 +282,12 @@ static int get_terminal_width() {
261 282 #endif
262 283 }
263 284
264- #ifdef LLAMA_USE_CURL
265 285 class File {
266 286   public:
267 287     FILE * file = nullptr;
268 288
269 289     FILE * open(const std::string & filename, const char * mode) {
270-         file = fopen(filename.c_str(), mode);
290+        file = ggml_fopen(filename.c_str(), mode);
271 291
272 292         return file;
273 293     }
@@ -303,6 +323,28 @@ class File {
303 323         return 0;
304 324     }
305 325
326+    std::string read_all(const std::string & filename) {
327+        open(filename, "r");
328+        if (!file) {
329+            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
330+            return "";
331+        }
332+        lock();
333+
334+        fseek(file, 0, SEEK_END);
335+        size_t size = ftell(file);
336+        fseek(file, 0, SEEK_SET);
337+
338+        std::string out;
339+        out.resize(size);
340+        size_t read_size = fread(&out[0], 1, size, file);
341+        if (read_size != size) {
342+            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
343+            return "";
344+        }
345+        return out;
346+    }
347+
306 348     ~File() {
307 349         if (fd >= 0) {
308 350 # ifdef _WIN32
@@ -327,6 +369,7 @@ class File {
327 369 # endif
328 370 };
329 371
372+ #ifdef LLAMA_USE_CURL
330 373 class HttpClient {
331 374   public:
332 375     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -1053,11 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
1053 1096     return 0;
1054 1097 }
1055 1098
1099+ // Reads the chat template from the given file into a string
1100+ static std::string read_chat_template_file(const std::string & chat_template_file) {
1101+     if (chat_template_file.empty()) {
1102+         return "";
1103+     }
1104+
1105+     File file;
1106+     std::string chat_template = "";
1107+     chat_template = file.read_all(chat_template_file);
1108+     if (chat_template.empty()) {
1109+         printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
1110+         return "";
1111+     }
1112+     return chat_template;
1113+ }
1114+
1056 1115 // Main chat loop function
1057- static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
1116+ static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
1058 1117     int prev_len = 0;
1059 1118     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
1060-     auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
1119+
1120+     std::string chat_template = "";
1121+     if (!chat_template_file.empty()) {
1122+         chat_template = read_chat_template_file(chat_template_file);
1123+     }
1124+     auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template);
1125+
1061 1126     static const bool stdout_a_terminal = is_stdout_a_terminal();
1062 1127     while (true) {
1063 1128         // Get user input
@@ -1143,7 +1208,7 @@ int main(int argc, const char ** argv) {
1143 1208         return 1;
1144 1209     }
1145 1210
1146-     if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
1211+     if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
1147 1212         return 1;
1148 1213     }
1149 1214
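
Taken together, the diff lets llama-run load a Jinja chat template from a file instead of relying only on the template embedded in the model's GGUF metadata: --chat-template-file reads the file via File::read_all() and hands its contents to common_chat_templates_init(), implicitly enabling --jinja. A hypothetical invocation (the model and template file names here are placeholders, not part of the commit) might look like:

    llama-run --chat-template-file ./chatml.jinja my-model.gguf

With the code as shown, an empty or unreadable template file yields an empty string, so common_chat_templates_init() falls back to the model's built-in template.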
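For illustration only, a file passed to --chat-template-file could contain a minimal ChatML-style Jinja template like the sketch below. The template itself is an assumption for demonstration, not part of the commit, though messages and add_generation_prompt are the variables chat templates conventionally receive:

    {%- for message in messages %}
    <|im_start|>{{ message.role }}
    {{ message.content }}<|im_end|>
    {%- endfor %}
    {%- if add_generation_prompt %}
    <|im_start|>assistant
    {%- endif %}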