@@ -181,6 +181,10 @@ class Opt {
181181 }
182182 }
183183
184+ if (model_.empty ()){
185+ return 1 ;
186+ }
187+
184188 return 0 ;
185189 }
186190
@@ -319,6 +323,10 @@ class HttpClient {
319323 public:
320324 int init (const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
321325 const bool progress, std::string * response_str = nullptr ) {
326+ if (std::filesystem::exists (output_file)) {
327+ return 0 ;
328+ }
329+
322330 std::string output_file_partial;
323331 curl = curl_easy_init ();
324332 if (!curl) {
@@ -346,7 +354,11 @@ class HttpClient {
346354 data.file_size = set_resume_point (output_file_partial);
347355 set_progress_options (progress, data);
348356 set_headers (headers);
349- perform (url);
357+ CURLcode res = perform (url);
358+ if (res != CURLE_OK){
359+ printe (" Fetching resource '%s' failed: %s\n " , url.c_str (), curl_easy_strerror (res));
360+ return 1 ;
361+ }
350362 if (!output_file.empty ()) {
351363 std::filesystem::rename (output_file_partial, output_file);
352364 }
@@ -411,16 +423,12 @@ class HttpClient {
411423 }
412424 }
413425
414- void perform (const std::string & url) {
415- CURLcode res;
426+ CURLcode perform (const std::string & url) {
416427 curl_easy_setopt (curl, CURLOPT_URL, url.c_str ());
417428 curl_easy_setopt (curl, CURLOPT_FOLLOWLOCATION, 1L );
418429 curl_easy_setopt (curl, CURLOPT_DEFAULT_PROTOCOL, " https" );
419430 curl_easy_setopt (curl, CURLOPT_FAILONERROR, 1L );
420- res = curl_easy_perform (curl);
421- if (res != CURLE_OK) {
422- printe (" curl_easy_perform() failed: %s\n " , curl_easy_strerror (res));
423- }
431+ return curl_easy_perform (curl);
424432 }
425433
426434 static std::string human_readable_time (double seconds) {
@@ -558,13 +566,14 @@ class LlamaData {
558566 }
559567
560568 sampler = initialize_sampler (opt);
569+
561570 return 0 ;
562571 }
563572
564573 private:
565574#ifdef LLAMA_USE_CURL
566- int download (const std::string & url, const std::vector<std:: string> & headers , const std::string & output_file ,
567- const bool progress , std::string * response_str = nullptr ) {
575+ int download (const std::string & url, const std::string & output_file , const bool progress ,
576+ const std::vector<std::string> & headers = {} , std::string * response_str = nullptr ) {
568577 HttpClient http;
569578 if (http.init (url, headers, output_file, progress, response_str)) {
570579 return 1 ;
@@ -573,57 +582,120 @@ class LlamaData {
573582 return 0 ;
574583 }
575584#else
576- int download (const std::string &, const std::vector<std:: string> &, const std::string &, const bool ,
585+ int download (const std::string &, const std::string &, const bool , const std::vector<std:: string> & = {} ,
577586 std::string * = nullptr ) {
578587 printe (" %s: llama.cpp built without libcurl, downloading from an url not supported.\n " , __func__);
588+
579589 return 1 ;
580590 }
581591#endif
582592
583- int huggingface_dl (const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
593+ // Helper function to handle model tag extraction and URL construction
594+ std::pair<std::string, std::string> extract_model_and_tag (std::string & model, const std::string & base_url) {
595+ std::string model_tag = " latest" ;
596+ const size_t colon_pos = model.find (' :' );
597+ if (colon_pos != std::string::npos) {
598+ model_tag = model.substr (colon_pos + 1 );
599+ model = model.substr (0 , colon_pos);
600+ }
601+
602+ std::string url = base_url + model + " /manifests/" + model_tag;
603+
604+ return { model, url };
605+ }
606+
607+ // Helper function to download and parse the manifest
608+ int download_and_parse_manifest (const std::string & url, const std::vector<std::string> & headers,
609+ nlohmann::json & manifest) {
610+ std::string manifest_str;
611+ int ret = download (url, " " , false , headers, &manifest_str);
612+ if (ret) {
613+ return ret;
614+ }
615+
616+ manifest = nlohmann::json::parse (manifest_str);
617+
618+ return 0 ;
619+ }
620+
621+ int huggingface_dl (std::string & model, const std::string & bn) {
584622 // Find the second occurrence of '/' after protocol string
585623 size_t pos = model.find (' /' );
586624 pos = model.find (' /' , pos + 1 );
625+ std::string hfr, hff;
626+ std::vector<std::string> headers = { " User-Agent: llama-cpp" , " Accept: application/json" };
627+ std::string url;
628+
587629 if (pos == std::string::npos) {
588- return 1 ;
630+ auto [model_name, manifest_url] = extract_model_and_tag (model, " https://huggingface.co/v2/" );
631+ hfr = model_name;
632+
633+ nlohmann::json manifest;
634+ int ret = download_and_parse_manifest (manifest_url, headers, manifest);
635+ if (ret) {
636+ return ret;
637+ }
638+
639+ hff = manifest[" ggufFile" ][" rfilename" ];
640+ } else {
641+ hfr = model.substr (0 , pos);
642+ hff = model.substr (pos + 1 );
589643 }
590644
591- const std::string hfr = model.substr (0 , pos);
592- const std::string hff = model.substr (pos + 1 );
593- const std::string url = " https://huggingface.co/" + hfr + " /resolve/main/" + hff;
594- return download (url, headers, bn, true );
645+ url = " https://huggingface.co/" + hfr + " /resolve/main/" + hff;
646+
647+ return download (url, bn, true , headers);
595648 }
596649
597- int ollama_dl (std::string & model, const std::vector<std::string> headers, const std::string & bn) {
650+ int ollama_dl (std::string & model, const std::string & bn) {
651+ const std::vector<std::string> headers = { " Accept: application/vnd.docker.distribution.manifest.v2+json" };
598652 if (model.find (' /' ) == std::string::npos) {
599653 model = " library/" + model;
600654 }
601655
602- std::string model_tag = " latest" ;
603- size_t colon_pos = model.find (' :' );
604- if (colon_pos != std::string::npos) {
605- model_tag = model.substr (colon_pos + 1 );
606- model = model.substr (0 , colon_pos);
607- }
608-
609- std::string manifest_url = " https://registry.ollama.ai/v2/" + model + " /manifests/" + model_tag;
610- std::string manifest_str;
611- const int ret = download (manifest_url, headers, " " , false , &manifest_str);
656+ auto [model_name, manifest_url] = extract_model_and_tag (model, " https://registry.ollama.ai/v2/" );
657+ nlohmann::json manifest;
658+ int ret = download_and_parse_manifest (manifest_url, {}, manifest);
612659 if (ret) {
613660 return ret;
614661 }
615662
616- nlohmann::json manifest = nlohmann::json::parse (manifest_str);
617- std::string layer;
663+ std::string layer;
618664 for (const auto & l : manifest[" layers" ]) {
619665 if (l[" mediaType" ] == " application/vnd.ollama.image.model" ) {
620666 layer = l[" digest" ];
621667 break ;
622668 }
623669 }
624670
625- std::string blob_url = " https://registry.ollama.ai/v2/" + model + " /blobs/" + layer;
626- return download (blob_url, headers, bn, true );
671+ std::string blob_url = " https://registry.ollama.ai/v2/" + model_name + " /blobs/" + layer;
672+
673+ return download (blob_url, bn, true , headers);
674+ }
675+
676+ int github_dl (const std::string & model, const std::string & bn) {
677+ std::string repository = model;
678+ std::string branch = " main" ;
679+ const size_t at_pos = model.find (' @' );
680+ if (at_pos != std::string::npos) {
681+ repository = model.substr (0 , at_pos);
682+ branch = model.substr (at_pos + 1 );
683+ }
684+
685+ const std::vector<std::string> repo_parts = string_split (repository, " /" );
686+ if (repo_parts.size () < 3 ) {
687+ printe (" Invalid GitHub repository format\n " );
688+ return 1 ;
689+ }
690+
691+ const std::string & org = repo_parts[0 ];
692+ const std::string & project = repo_parts[1 ];
693+ std::string url = " https://raw.githubusercontent.com/" + org + " /" + project + " /" + branch;
694+ for (size_t i = 2 ; i < repo_parts.size (); ++i) {
695+ url += " /" + repo_parts[i];
696+ }
697+
698+ return download (url, bn, true );
627699 }
628700
629701 std::string basename (const std::string & path) {
@@ -653,22 +725,23 @@ class LlamaData {
653725 return ret;
654726 }
655727
656- const std::string bn = basename (model_);
657- const std::vector<std::string> headers = { " --header" ,
658- " Accept: application/vnd.docker.distribution.manifest.v2+json" };
659- if (string_starts_with (model_, " hf://" ) || string_starts_with (model_, " huggingface://" )) {
660- rm_until_substring (model_, " ://" );
661- ret = huggingface_dl (model_, headers, bn);
662- } else if (string_starts_with (model_, " hf.co/" )) {
728+ const std::string bn = basename (model_);
729+ if (string_starts_with (model_, " hf://" ) || string_starts_with (model_, " huggingface://" ) ||
730+ string_starts_with (model_, " hf.co/" )) {
663731 rm_until_substring (model_, " hf.co/" );
664- ret = huggingface_dl (model_, headers, bn);
665- } else if (string_starts_with (model_, " ollama://" )) {
666732 rm_until_substring (model_, " ://" );
667- ret = ollama_dl (model_, headers, bn);
668- } else if (string_starts_with (model_, " https://" )) {
669- ret = download (model_, headers, bn, true );
670- } else {
671- ret = ollama_dl (model_, headers, bn);
733+ ret = huggingface_dl (model_, bn);
734+ } else if ((string_starts_with (model_, " https://" ) || string_starts_with (model_, " http://" )) &&
735+ !string_starts_with (model_, " https://ollama.com/library/" )) {
736+ ret = download (model_, bn, true );
737+ } else if (string_starts_with (model_, " github:" ) || string_starts_with (model_, " github://" )) {
738+ rm_until_substring (model_, " github:" );
739+ rm_until_substring (model_, " ://" );
740+ ret = github_dl (model_, bn);
741+ } else { // ollama:// or nothing
742+ rm_until_substring (model_, " ollama.com/library/" );
743+ rm_until_substring (model_, " ://" );
744+ ret = ollama_dl (model_, bn);
672745 }
673746
674747 model_ = bn;
0 commit comments