 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
 #endif
 
 #include <signal.h>
+#include <sys/stat.h>
 
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
 }
 #endif
 
+#define LLAMA_USE_CURL
+
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX);  // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +58,7 @@ static std::string fmt(const char * fmt, ...) {
 }
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +107,8 @@ class Opt {
 
     llama_context_params ctx_params;
     llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -137,7 +144,7 @@ class Opt {
     }
 
     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +173,11 @@ class Opt {
 
                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
@@ -475,7 +487,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }
 
-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }
 
     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +529,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
                progress_suffix.c_str());
     }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +553,23 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;
 
     int init(Opt & opt) {
         model = initialize_model(opt);
         if (!model) {
             return 1;
         }
 
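+        // load the chat template from disk; an empty string means fall back to the model's built-in template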
+        chat_template = initialize_chat_template(opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
@@ -573,21 +592,76 @@ class LlamaData {
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        // if the template already exists, don't download it
+        struct stat info;
+        if (stat(tn.c_str(), &info) == 0) {
+            return 0;
+        }
+
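+        // fetch tokenizer_config.json from the Hugging Face repo; the chat template, if present, is in its "chat_template" field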
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
@@ -607,16 +681,34 @@ class LlamaData {
         }
 
         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
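+        // scan the manifest layers for the model blob and the optional chat template blob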
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
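+        // fetch the chat template blob first, if the manifest provides one and it is not already cached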
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
+            }
+        }
+
+        return 0;
     }
 
     std::string basename(const std::string & path) {
@@ -638,38 +730,38 @@ class LlamaData {
         return 0;
     }
 
-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-
             return ret;
         }
 
+        remove_proto(model_);
         const std::string bn = basename(model_);
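+        // honor a user-supplied template path; otherwise cache the template next to the model as "<model>.template"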
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
         if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
+            ret = huggingface_dl(model_, headers, bn, tn);
         } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         } else if (string_starts_with(model_, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }
 
-        model_ = bn;
+        model_         = bn;
+        chat_template_ = tn;
 
         return ret;
     }
 
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +794,27 @@ class LlamaData {
 
         return sampler;
     }
+
+    std::string initialize_chat_template(const Opt & opt) {
+        // if the template file doesn't exist, just return an empty string
+        struct stat info;
+        if (stat(opt.chat_template_.c_str(), &info) != 0) {
+            return "";
+        }
+
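+        // read the whole template file into a string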
+        std::ifstream tmpl_file;
+        tmpl_file.open(opt.chat_template_);
+        if (tmpl_file.fail()) {
+            printe("failed to open chat template: '%s'\n", opt.chat_template_.c_str());
+            return "";
+        }
+
+        std::stringstream stream;
+        stream << tmpl_file.rdbuf();
+        tmpl_file.close();
+
+        return stream.str();
+    }
 };
 
 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,13 +826,15 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
+        llama_data.model.get(), llama_data.chat_template.empty() ? nullptr : llama_data.chat_template.c_str(),
+        llama_data.messages.data(), llama_data.messages.size(), append, append ? llama_data.fmtted.data() : nullptr,
+        append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
+        result = llama_chat_apply_template(
+            llama_data.model.get(), llama_data.chat_template.empty() ? nullptr : llama_data.chat_template.c_str(),
+            llama_data.messages.data(), llama_data.messages.size(), append, llama_data.fmtted.data(),
+            llama_data.fmtted.size());
     }
 
     return result;