@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif
 
 #include <signal.h>
+#include <sys/stat.h>
 
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -35,13 +37,14 @@
 #endif
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
     GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
 }
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {
 
     llama_context_params ctx_params;
     llama_model_params model_params;
     std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -137,7 +142,7 @@ class Opt {
     }
 
     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +171,11 @@ class Opt {
 
                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
@@ -475,7 +485,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }
 
-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }
 
     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
                progress_suffix.c_str());
     }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +551,23 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;
 
     int init(Opt & opt) {
         model = initialize_model(opt);
         if (!model) {
             return 1;
         }
 
+        chat_template = initialize_chat_template(model, opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
@@ -573,21 +590,76 @@ class LlamaData {
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        // if template already exists, don't download it
+        struct stat info;
+        if (stat(tn.c_str(), &info) == 0) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
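
For reference: on Hugging Face the chat template normally ships as an optional "chat_template" string inside tokenizer_config.json, which is what huggingface_dl_tmpl() above fetches and writes to the local .template file next to the model. A minimal standalone sketch of that JSON lookup, assuming nlohmann::json (already used in this file); the extract_chat_template() helper and its defensive checks are illustrative, not part of the patch:

    #include <optional>
    #include <string>

    #include <nlohmann/json.hpp>

    // Returns the "chat_template" string from a tokenizer_config.json payload, if present.
    static std::optional<std::string> extract_chat_template(const std::string & tokenizer_config_str) {
        const nlohmann::json config = nlohmann::json::parse(tokenizer_config_str, nullptr, /*allow_exceptions=*/false);
        if (config.is_discarded() || !config.contains("chat_template") || !config["chat_template"].is_string()) {
            return std::nullopt; // absent, malformed, or not a plain string (some repos store a list of named templates)
        }
        return config["chat_template"].get<std::string>();
    }
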
@@ -607,16 +679,34 @@ class LlamaData {
         }
 
         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        return 0;
     }
 
     std::string basename(const std::string & path) {
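
For reference: an Ollama registry manifest lists each artifact as a layer with a "mediaType" and a "digest"; ollama_dl() above now records two digests, one for application/vnd.ollama.image.model and one for application/vnd.ollama.image.template, and downloads each blob only when the corresponding local file is missing. A small sketch of that layer selection, again assuming nlohmann::json; pick_layer_digest() is a hypothetical helper, not part of the patch:

    #include <string>

    #include <nlohmann/json.hpp>

    // Scans manifest["layers"] for the first layer whose mediaType matches and returns its digest,
    // e.g. "sha256:...", which the caller appends to the /v2/<model>/blobs/ URL.
    static std::string pick_layer_digest(const nlohmann::json & manifest, const std::string & media_type) {
        if (!manifest.contains("layers")) {
            return "";
        }
        for (const auto & layer : manifest["layers"]) {
            if (layer.value("mediaType", "") == media_type) {
                return layer.value("digest", "");
            }
        }
        return "";
    }
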
@@ -628,6 +718,15 @@ class LlamaData {
         return path.substr(pos + 1);
     }
 
+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3); // Include "://"
+    }
+
     int remove_proto(std::string & model_) {
         const std::string::size_type pos = model_.find("://");
         if (pos == std::string::npos) {
@@ -638,38 +737,40 @@ class LlamaData {
         return 0;
     }
 
-    int resolve_model(std::string & model_) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
         int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-
             return ret;
         }
 
+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }
 
         model_ = bn;
+        chat_template_ = tn;
 
         return ret;
     }
 
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +803,27 @@ class LlamaData {
 
         return sampler;
     }
+
+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        // if the chat template file doesn't exist, fall back to the model's built-in template
+        struct stat info;
+        if (stat(opt.chat_template_.c_str(), &info) != 0) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        std::ifstream tmpl_file;
+        tmpl_file.open(opt.chat_template_);
+        if (tmpl_file.fail()) {
+            printe("failed to open chat template: '%s'\n", opt.chat_template_.c_str());
+            return "";
+        }
+
+        std::stringstream stream;
+        stream << tmpl_file.rdbuf();
+        tmpl_file.close();
+
+        return stream.str();
+    }
 };
 
 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +835,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
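
For reference: llama_chat_apply_template() follows the size-then-fill convention — its return value is the number of bytes the formatted prompt needs, so when that exceeds the buffer the caller resizes fmtted and calls it again, which is exactly what the hunk above does. A self-contained sketch of the same pattern, assuming the template-string signature used in this patch; render_chat() is a hypothetical helper, not part of the patch:

    #include <string>
    #include <vector>

    #include "llama.h"

    // Formats a chat with the given template string, growing the buffer if the first pass was too small.
    static std::string render_chat(const std::string & tmpl, const std::vector<llama_chat_message> & msgs) {
        std::vector<char> buf(1024);
        int32_t n = llama_chat_apply_template(tmpl.c_str(), msgs.data(), msgs.size(), true, buf.data(), buf.size());
        if (n > (int32_t) buf.size()) {
            buf.resize(n); // first call reported the required size; run it again with enough room
            n = llama_chat_apply_template(tmpl.c_str(), msgs.data(), msgs.size(), true, buf.data(), buf.size());
        }
        return n < 0 ? std::string() : std::string(buf.data(), n);
    }
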
@@ -730,8 +852,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                            std::vector<llama_token> & prompt_tokens) {
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
         printe("failed to tokenize the prompt\n");
         return -1;
     }