#if defined(_WIN32)
-#    include <windows.h>
#    include <io.h>
+#    include <windows.h>
#else
#    include <sys/file.h>
#    include <sys/ioctl.h>
#endif

#include <signal.h>
+#include <sys/stat.h>

#include <climits>
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <filesystem>
+#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#endif

GGML_ATTRIBUTE_FORMAT(1, 2)
+
static std::string fmt(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);
    const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
    std::string buf;
    buf.resize(size);
    const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
}

GGML_ATTRIBUTE_FORMAT(1, 2)
+
static int printe(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {

    llama_context_params ctx_params;
    llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
    std::string user;
    int context_size = -1, ngl = -1;
    float temperature = -1;
@@ -137,7 +142,7 @@ class Opt {
    }

    int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
        for (int i = 1, positional_args_i = 0; i < argc; ++i) {
            if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +171,11 @@ class Opt {

                ++positional_args_i;
                model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
            } else if (positional_args_i == 1) {
                ++positional_args_i;
                user = argv[i];
@@ -475,7 +485,9 @@ class HttpClient {
        return (now_downloaded_plus_file_size * 100) / total_to_download;
    }

-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }

    static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
        const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
        printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
               progress_suffix.c_str());
    }
+
    // Function to write data to a file
    static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
        FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +551,23 @@ class LlamaData {
    std::vector<llama_chat_message> messages;
    std::vector<std::string> msg_strs;
    std::vector<char> fmtted;
+    std::string chat_template;

    int init(Opt & opt) {
        model = initialize_model(opt);
        if (!model) {
            return 1;
        }

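+        // chat_template comes from --chat-template, a downloaded <model>.template file, or the model's built-in template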
+        chat_template = initialize_chat_template(model, opt);
+
        context = initialize_context(model, opt);
        if (!context) {
            return 1;
        }

        sampler = initialize_sampler(opt);
+
        return 0;
    }

@@ -573,21 +590,74 @@ class LlamaData {
    }
#endif

-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
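+    // Download the chat template from the repo's tokenizer_config.json and cache it in tn, unless it already exists locally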
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        if (std::filesystem::exists(tn)) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
        // Find the second occurrence of '/' after protocol string
        size_t pos = model.find('/');
        pos = model.find('/', pos + 1);
        if (pos == std::string::npos) {
            return 1;
        }
-
        const std::string hfr = model.substr(0, pos);
        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
    }

-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
        if (model.find('/') == std::string::npos) {
            model = "library/" + model;
        }
@@ -607,16 +677,34 @@ class LlamaData {
        }

        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
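+        // The manifest lists blobs by media type; grab the digests for the model weights and the chat template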
        for (const auto & l : manifest["layers"]) {
            if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
            }
        }

-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        return 0;
    }

    std::string basename(const std::string & path) {
@@ -628,6 +716,15 @@ class LlamaData {
        return path.substr(pos + 1);
    }

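+    // Return the protocol prefix of the model string (e.g. "hf://", "ollama://"), or an empty string if there is none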
+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3); // Include "://"
+    }
+
    int remove_proto(std::string & model_) {
        const std::string::size_type pos = model_.find("://");
        if (pos == std::string::npos) {
@@ -638,38 +735,40 @@ class LlamaData {
        return 0;
    }

-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
            remove_proto(model_);
-
            return ret;
        }

+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
        const std::string bn = basename(model_);
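+        // Cache path for the chat template: "<model basename>.template" unless --chat-template supplied a path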
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
        const std::vector<std::string> headers = { "--header",
                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
            download(model_, headers, bn, true);
        } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
        }

-        model_ = bn;
+        model_ = bn;
+        chat_template_ = tn;

        return ret;
    }

    // Initializes the model and returns a unique pointer to it
    llama_model_ptr initialize_model(Opt & opt) {
        ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
        printe(
            "\r%*s"
            "\rLoading model",
@@ -702,6 +801,31 @@ class LlamaData {

        return sampler;
    }
+
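+    // Read the chat template from the file resolved via --chat-template or download, falling back to the model's built-in template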
+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        if (!std::filesystem::exists(opt.chat_template_)) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        FILE * tmpl_file = ggml_fopen(opt.chat_template_.c_str(), "r");
+        if (!tmpl_file) {
+            std::cerr << "Error opening file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+
+        fseek(tmpl_file, 0, SEEK_END);
+        size_t size = ftell(tmpl_file);
+        fseek(tmpl_file, 0, SEEK_SET);
+
+        std::vector<unsigned char> data(size);
+        size_t read_size = fread(data.data(), 1, size, tmpl_file);
+        fclose(tmpl_file);
+        if (read_size != size) {
+            std::cerr << "Error reading file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+        return std::string(data.begin(), data.end());
+    }
};

// Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +837,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
    int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
        llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
                                           llama_data.fmtted.size());
    }
@@ -730,8 +854,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                           std::vector<llama_token> & prompt_tokens) {
    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
    prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
        printe("failed to tokenize the prompt\n");
        return -1;
    }