@@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl,
303303// download one single file from remote URL to local path
304304static bool common_download_file_single_online (const std::string & url,
305305 const std::string & path,
306- const std::string & bearer_token) {
306+ const std::string & bearer_token,
307+ const std::vector<std::pair<std::string, std::string>> & headers) {
307308 static const int max_attempts = 3 ;
308309 static const int retry_delay_seconds = 2 ;
309310 for (int i = 0 ; i < max_attempts; ++i) {
@@ -322,10 +323,14 @@ static bool common_download_file_single_online(const std::string & url,
322323
323324 // Initialize libcurl
324325 curl_ptr curl (curl_easy_init (), &curl_easy_cleanup);
325- common_load_model_from_url_headers headers ;
326- curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &headers );
326+ common_load_model_from_url_headers response_headers ;
327+ curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &response_headers );
327328 curl_slist_ptr http_headers;
328- const bool was_perform_successful = common_download_head (curl.get (), http_headers, url, bearer_token);
329+ for (const auto & h : headers) {
330+ auto header_str = h.first + " : " + h.second ;
331+ http_headers.ptr = curl_slist_append (http_headers.ptr , header_str.c_str ());
332+ }
333+ const bool was_perform_successful = common_download_head (curl.get (), http_headers, url, bearer_token);
329334 if (!was_perform_successful) {
330335 head_request_ok = false ;
331336 }
@@ -345,15 +350,15 @@ static bool common_download_file_single_online(const std::string & url,
345350 if (head_request_ok) {
346351 // check if ETag or Last-Modified headers are different
347352 // if it is, we need to download the file again
348- if (!etag.empty () && etag != headers .etag ) {
353+ if (!etag.empty () && etag != response_headers .etag ) {
349354 LOG_WRN (" %s: ETag header is different (%s != %s): triggering a new download\n " , __func__, etag.c_str (),
350- headers .etag .c_str ());
355+ response_headers .etag .c_str ());
351356 should_download = true ;
352357 should_download_from_scratch = true ;
353358 }
354359 }
355360
356- const bool accept_ranges_supported = !headers .accept_ranges .empty () && headers .accept_ranges != " none" ;
361+ const bool accept_ranges_supported = !response_headers .accept_ranges .empty () && response_headers .accept_ranges != " none" ;
357362 if (should_download) {
358363 if (file_exists &&
359364 !accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
@@ -381,13 +386,13 @@ static bool common_download_file_single_online(const std::string & url,
381386 }
382387 }
383388 if (head_request_ok) {
384- write_etag (path, headers .etag );
389+ write_etag (path, response_headers .etag );
385390 }
386391
387392 // start the download
388393 LOG_INF (" %s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n " ,
389394 __func__, llama_download_hide_password_in_url (url).c_str (), path_temporary.c_str (),
390- headers .etag .c_str (), headers .last_modified .c_str ());
395+ response_headers .etag .c_str (), response_headers .last_modified .c_str ());
391396 const bool was_pull_successful = common_pull_file (curl.get (), path_temporary);
392397 if (!was_pull_successful) {
393398 if (i + 1 < max_attempts) {
@@ -433,7 +438,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
433438 curl_easy_setopt (curl.get (), CURLOPT_VERBOSE, 1L );
434439 typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
435440 auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
436- auto data_vec = static_cast <std::vector<char > *>(data);
441+ auto * data_vec = static_cast <std::vector<char > *>(data);
437442 data_vec->insert (data_vec->end (), (char *)ptr, (char *)ptr + size * nmemb);
438443 return size * nmemb;
439444 };
@@ -565,7 +570,8 @@ static bool common_pull_file(httplib::Client & cli,
565570// download one single file from remote URL to local path
566571static bool common_download_file_single_online (const std::string & url,
567572 const std::string & path,
568- const std::string & bearer_token) {
573+ const std::string & bearer_token,
574+ const std::vector<std::pair<std::string, std::string>> & headers) {
569575 static const int max_attempts = 3 ;
570576 static const int retry_delay_seconds = 2 ;
571577
@@ -575,6 +581,9 @@ static bool common_download_file_single_online(const std::string & url,
575581 if (!bearer_token.empty ()) {
576582 default_headers.insert ({" Authorization" , " Bearer " + bearer_token});
577583 }
584+ for (const auto & h : headers) {
585+ default_headers.insert ({h.first , h.second });
586+ }
578587 cli.set_default_headers (default_headers);
579588
580589 const bool file_exists = std::filesystem::exists (path);
@@ -718,9 +727,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
718727static bool common_download_file_single (const std::string & url,
719728 const std::string & path,
720729 const std::string & bearer_token,
721- bool offline) {
730+ bool offline,
731+ const std::vector<std::pair<std::string, std::string>> & headers) {
722732 if (!offline) {
723- return common_download_file_single_online (url, path, bearer_token);
733+ return common_download_file_single_online (url, path, bearer_token, headers );
724734 }
725735
726736 if (!std::filesystem::exists (path)) {
@@ -734,13 +744,24 @@ static bool common_download_file_single(const std::string & url,
734744
735745// download multiple files from remote URLs to local paths
736746// the input is a vector of pairs <url, path>
737- static bool common_download_file_multiple (const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
747+ static bool common_download_file_multiple (const std::vector<std::pair<std::string, std::string>> & urls,
748+ const std::string & bearer_token,
749+ bool offline,
750+ const std::vector<std::pair<std::string, std::string>> & headers) {
738751 // Prepare download in parallel
739752 std::vector<std::future<bool >> futures_download;
753+ futures_download.reserve (urls.size ());
754+
740755 for (auto const & item : urls) {
741- futures_download.push_back (std::async (std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
742- return common_download_file_single (it.first , it.second , bearer_token, offline);
743- }, item));
756+ futures_download.push_back (
757+ std::async (
758+ std::launch::async,
759+ [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
760+ return common_download_file_single (it.first , it.second , bearer_token, offline, headers);
761+ },
762+ item
763+ )
764+ );
744765 }
745766
746767 // Wait for all downloads to complete
@@ -753,17 +774,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
753774 return true ;
754775}
755776
756- bool common_download_model (
757- const common_params_model & model ,
758- const std::string & bearer_token ,
759- bool offline ) {
777+ bool common_download_model (const common_params_model & model,
778+ const std::string & bearer_token ,
779+ bool offline ,
780+ const std::vector<std::pair<std::string, std::string>> & headers ) {
760781 // Basic validation of the model.url
761782 if (model.url .empty ()) {
762783 LOG_ERR (" %s: invalid model url\n " , __func__);
763784 return false ;
764785 }
765786
766- if (!common_download_file_single (model.url , model.path , bearer_token, offline)) {
787+ if (!common_download_file_single (model.url , model.path , bearer_token, offline, headers )) {
767788 return false ;
768789 }
769790
@@ -822,7 +843,7 @@ bool common_download_model(
822843 }
823844
824845 // Download in parallel
825- common_download_file_multiple (urls, bearer_token, offline);
846+ common_download_file_multiple (urls, bearer_token, offline, headers );
826847 }
827848
828849 return true ;
@@ -1016,7 +1037,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
10161037 std::string local_path = fs_get_cache_file (model_filename);
10171038
10181039 const std::string blob_url = url_prefix + " /blobs/" + gguf_digest;
1019- if (!common_download_file_single (blob_url, local_path, token, false )) {
1040+ if (!common_download_file_single (blob_url, local_path, token, false , {} )) {
10201041 throw std::runtime_error (" Failed to download Docker Model" );
10211042 }
10221043
@@ -1034,7 +1055,10 @@ common_hf_file_res common_get_hf_file(const std::string &, const std::string &,
10341055 throw std::runtime_error (" download functionality is not enabled in this build" );
10351056}
10361057
1037- bool common_download_model (const common_params_model &, const std::string &, bool ) {
1058+ bool common_download_model (const common_params_model &,
1059+ const std::string &,
1060+ bool ,
1061+ const std::vector<std::pair<std::string, std::string>> &) {
10381062 throw std::runtime_error (" download functionality is not enabled in this build" );
10391063}
10401064
0 commit comments