@@ -319,6 +319,10 @@ class HttpClient {
319319 public:
320320 int init (const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
321321 const bool progress, std::string * response_str = nullptr ) {
322+ if (std::filesystem::exists (output_file)) {
323+ return 0 ;
324+ }
325+
322326 std::string output_file_partial;
323327 curl = curl_easy_init ();
324328 if (!curl) {
@@ -558,13 +562,14 @@ class LlamaData {
558562 }
559563
560564 sampler = initialize_sampler (opt);
565+
561566 return 0 ;
562567 }
563568
564569 private:
565570#ifdef LLAMA_USE_CURL
566- int download (const std::string & url, const std::vector<std:: string> & headers , const std::string & output_file ,
567- const bool progress , std::string * response_str = nullptr ) {
571+ int download (const std::string & url, const std::string & output_file , const bool progress ,
572+ const std::vector<std::string> & headers = {} , std::string * response_str = nullptr ) {
568573 HttpClient http;
569574 if (http.init (url, headers, output_file, progress, response_str)) {
570575 return 1 ;
@@ -573,57 +578,95 @@ class LlamaData {
573578 return 0 ;
574579 }
575580#else
576- int download (const std::string &, const std::vector<std:: string> &, const std::string &, const bool ,
581+ int download (const std::string &, const std::string &, const bool , const std::vector<std:: string> & = {} ,
577582 std::string * = nullptr ) {
578583 printe (" %s: llama.cpp built without libcurl, downloading from an url not supported.\n " , __func__);
584+
579585 return 1 ;
580586 }
581587#endif
582588
583- int huggingface_dl (const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
// Helper function to handle model tag extraction and URL construction.
//
// model    - in/out: "name" or "name:tag". On return the ":tag" suffix (everything from the
//            first ':' onwards) has been removed from it.
// base_url - prefix the manifest URL is built from, e.g. "https://registry.ollama.ai/v2/".
//
// Returns { model name without tag, base_url + name + "/manifests/" + tag }.
// The tag is "latest" when the model carries no tag or an empty one ("name:").
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
    std::string  model_tag = "latest";
    const size_t colon_pos = model.find(':');
    if (colon_pos != std::string::npos) {
        // An empty tag would produce a manifest URL without a reference, so keep the default then.
        if (colon_pos + 1 < model.size()) {
            model_tag = model.substr(colon_pos + 1);
        }
        model.erase(colon_pos);
    }

    std::string url = base_url + model + "/manifests/" + model_tag;

    return { model, std::move(url) };
}
602+
603+ // Helper function to download and parse the manifest
604+ int download_and_parse_manifest (const std::string & url, const std::vector<std::string> & headers,
605+ nlohmann::json & manifest) {
606+ std::string manifest_str;
607+ int ret = download (url, " " , false , headers, &manifest_str);
608+ if (ret) {
609+ return ret;
610+ }
611+
612+ manifest = nlohmann::json::parse (manifest_str);
613+
614+ return 0 ;
615+ }
616+
617+ int huggingface_dl (std::string & model, const std::string & bn) {
584618 // Find the second occurrence of '/' after protocol string
585619 size_t pos = model.find (' /' );
586620 pos = model.find (' /' , pos + 1 );
621+ std::string hfr, hff;
622+ std::vector<std::string> headers = { " User-Agent: llama-cpp" , " Accept: application/json" };
623+ std::string url;
624+
587625 if (pos == std::string::npos) {
588- return 1 ;
626+ auto [model_name, manifest_url] = extract_model_and_tag (model, " https://huggingface.co/v2/" );
627+ hfr = model_name;
628+
629+ nlohmann::json manifest;
630+ int ret = download_and_parse_manifest (manifest_url, headers, manifest);
631+ if (ret) {
632+ return ret;
633+ }
634+
635+ hff = manifest[" ggufFile" ][" rfilename" ];
636+ } else {
637+ hfr = model.substr (0 , pos);
638+ hff = model.substr (pos + 1 );
589639 }
590640
591- const std::string hfr = model.substr (0 , pos);
592- const std::string hff = model.substr (pos + 1 );
593- const std::string url = " https://huggingface.co/" + hfr + " /resolve/main/" + hff;
594- return download (url, headers, bn, true );
641+ url = " https://huggingface.co/" + hfr + " /resolve/main/" + hff;
642+
643+ return download (url, bn, true , headers);
595644 }
596645
597- int ollama_dl (std::string & model, const std::vector<std::string> headers, const std::string & bn) {
646+ int ollama_dl (std::string & model, const std::string & bn) {
647+ const std::vector<std::string> headers = { " Accept: application/vnd.docker.distribution.manifest.v2+json" };
598648 if (model.find (' /' ) == std::string::npos) {
599649 model = " library/" + model;
600650 }
601651
602- std::string model_tag = " latest" ;
603- size_t colon_pos = model.find (' :' );
604- if (colon_pos != std::string::npos) {
605- model_tag = model.substr (colon_pos + 1 );
606- model = model.substr (0 , colon_pos);
607- }
608-
609- std::string manifest_url = " https://registry.ollama.ai/v2/" + model + " /manifests/" + model_tag;
610- std::string manifest_str;
611- const int ret = download (manifest_url, headers, " " , false , &manifest_str);
652+ auto [model_name, manifest_url] = extract_model_and_tag (model, " https://registry.ollama.ai/v2/" );
653+ nlohmann::json manifest;
654+ int ret = download_and_parse_manifest (manifest_url, {}, manifest);
612655 if (ret) {
613656 return ret;
614657 }
615658
616- nlohmann::json manifest = nlohmann::json::parse (manifest_str);
617- std::string layer;
659+ std::string layer;
618660 for (const auto & l : manifest[" layers" ]) {
619661 if (l[" mediaType" ] == " application/vnd.ollama.image.model" ) {
620662 layer = l[" digest" ];
621663 break ;
622664 }
623665 }
624666
625- std::string blob_url = " https://registry.ollama.ai/v2/" + model + " /blobs/" + layer;
626- return download (blob_url, headers, bn, true );
667+ std::string blob_url = " https://registry.ollama.ai/v2/" + model_name + " /blobs/" + layer;
668+
669+ return download (blob_url, bn, true , headers);
627670 }
628671
629672 std::string basename (const std::string & path) {
@@ -653,22 +696,18 @@ class LlamaData {
653696 return ret;
654697 }
655698
656- const std::string bn = basename (model_);
657- const std::vector<std::string> headers = { " --header" ,
658- " Accept: application/vnd.docker.distribution.manifest.v2+json" };
699+ const std::string bn = basename (model_);
659700 if (string_starts_with (model_, " hf://" ) || string_starts_with (model_, " huggingface://" )) {
660701 rm_until_substring (model_, " ://" );
661- ret = huggingface_dl (model_, headers, bn);
702+ ret = huggingface_dl (model_, bn);
662703 } else if (string_starts_with (model_, " hf.co/" )) {
663704 rm_until_substring (model_, " hf.co/" );
664- ret = huggingface_dl (model_, headers, bn);
665- } else if (string_starts_with (model_, " ollama://" )) {
666- rm_until_substring (model_, " ://" );
667- ret = ollama_dl (model_, headers, bn);
705+ ret = huggingface_dl (model_, bn);
668706 } else if (string_starts_with (model_, " https://" )) {
669- ret = download (model_, headers, bn, true );
670- } else {
671- ret = ollama_dl (model_, headers, bn);
707+ ret = download (model_, bn, true );
708+ } else { // ollama:// or nothing
709+ rm_until_substring (model_, " ://" );
710+ ret = ollama_dl (model_, bn);
672711 }
673712
674713 model_ = bn;
0 commit comments