@@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl,
303303// download one single file from remote URL to local path
304304static bool common_download_file_single_online (const std::string & url,
305305 const std::string & path,
306- const std::string & bearer_token) {
306+ const std::string & bearer_token,
307+ const std::vector<std::pair<std::string, std::string>> & headers) {
307308 static const int max_attempts = 3 ;
308309 static const int retry_delay_seconds = 2 ;
309310 for (int i = 0 ; i < max_attempts; ++i) {
@@ -322,10 +323,14 @@ static bool common_download_file_single_online(const std::string & url,
322323
323324 // Initialize libcurl
324325 curl_ptr curl (curl_easy_init (), &curl_easy_cleanup);
325- common_load_model_from_url_headers headers ;
326- curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &headers );
326+ common_load_model_from_url_headers response_headers ;
327+ curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &response_headers );
327328 curl_slist_ptr http_headers;
328- const bool was_perform_successful = common_download_head (curl.get (), http_headers, url, bearer_token);
329+ for (const auto & h : headers) {
330+ auto header_str = h.first + " : " + h.second ;
331+ http_headers.ptr = curl_slist_append (http_headers.ptr , header_str.c_str ());
332+ }
333+ const bool was_perform_successful = common_download_head (curl.get (), http_headers, url, bearer_token);
329334 if (!was_perform_successful) {
330335 head_request_ok = false ;
331336 }
@@ -345,15 +350,15 @@ static bool common_download_file_single_online(const std::string & url,
345350 if (head_request_ok) {
346351 // check if ETag or Last-Modified headers are different
347352 // if it is, we need to download the file again
348- if (!etag.empty () && etag != headers .etag ) {
353+ if (!etag.empty () && etag != response_headers .etag ) {
349354 LOG_WRN (" %s: ETag header is different (%s != %s): triggering a new download\n " , __func__, etag.c_str (),
350- headers .etag .c_str ());
355+ response_headers .etag .c_str ());
351356 should_download = true ;
352357 should_download_from_scratch = true ;
353358 }
354359 }
355360
356- const bool accept_ranges_supported = !headers .accept_ranges .empty () && headers .accept_ranges != " none" ;
361+ const bool accept_ranges_supported = !response_headers .accept_ranges .empty () && response_headers .accept_ranges != " none" ;
357362 if (should_download) {
358363 if (file_exists &&
359364 !accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
@@ -381,13 +386,13 @@ static bool common_download_file_single_online(const std::string & url,
381386 }
382387 }
383388 if (head_request_ok) {
384- write_etag (path, headers .etag );
389+ write_etag (path, response_headers .etag );
385390 }
386391
387392 // start the download
388393 LOG_INF (" %s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n " ,
389394 __func__, llama_download_hide_password_in_url (url).c_str (), path_temporary.c_str (),
390- headers .etag .c_str (), headers .last_modified .c_str ());
395+ response_headers .etag .c_str (), response_headers .last_modified .c_str ());
391396 const bool was_pull_successful = common_pull_file (curl.get (), path_temporary);
392397 if (!was_pull_successful) {
393398 if (i + 1 < max_attempts) {
@@ -433,7 +438,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
433438 curl_easy_setopt (curl.get (), CURLOPT_VERBOSE, 1L );
434439 typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
435440 auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
436- auto data_vec = static_cast <std::vector<char > *>(data);
441+ auto * data_vec = static_cast <std::vector<char > *>(data);
437442 data_vec->insert (data_vec->end (), (char *)ptr, (char *)ptr + size * nmemb);
438443 return size * nmemb;
439444 };
@@ -572,7 +577,8 @@ static bool common_pull_file(httplib::Client & cli,
572577// download one single file from remote URL to local path
573578static bool common_download_file_single_online (const std::string & url,
574579 const std::string & path,
575- const std::string & bearer_token) {
580+ const std::string & bearer_token,
581+ const std::vector<std::pair<std::string, std::string>> & headers) {
576582 static const int max_attempts = 3 ;
577583 static const int retry_delay_seconds = 2 ;
578584
@@ -582,6 +588,9 @@ static bool common_download_file_single_online(const std::string & url,
582588 if (!bearer_token.empty ()) {
583589 default_headers.insert ({" Authorization" , " Bearer " + bearer_token});
584590 }
591+ for (const auto & h : headers) {
592+ default_headers.insert ({h.first , h.second });
593+ }
585594 cli.set_default_headers (default_headers);
586595
587596 const bool file_exists = std::filesystem::exists (path);
@@ -725,9 +734,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
725734static bool common_download_file_single (const std::string & url,
726735 const std::string & path,
727736 const std::string & bearer_token,
728- bool offline) {
737+ bool offline,
738+ const std::vector<std::pair<std::string, std::string>> & headers) {
729739 if (!offline) {
730- return common_download_file_single_online (url, path, bearer_token);
740+ return common_download_file_single_online (url, path, bearer_token, headers );
731741 }
732742
733743 if (!std::filesystem::exists (path)) {
@@ -741,13 +751,24 @@ static bool common_download_file_single(const std::string & url,
741751
742752// download multiple files from remote URLs to local paths
743753// the input is a vector of pairs <url, path>
744- static bool common_download_file_multiple (const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
754+ static bool common_download_file_multiple (const std::vector<std::pair<std::string, std::string>> & urls,
755+ const std::string & bearer_token,
756+ bool offline,
757+ const std::vector<std::pair<std::string, std::string>> & headers) {
745758 // Prepare download in parallel
746759 std::vector<std::future<bool >> futures_download;
760+ futures_download.reserve (urls.size ());
761+
747762 for (auto const & item : urls) {
748- futures_download.push_back (std::async (std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
749- return common_download_file_single (it.first , it.second , bearer_token, offline);
750- }, item));
763+ futures_download.push_back (
764+ std::async (
765+ std::launch::async,
766+ [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
767+ return common_download_file_single (it.first , it.second , bearer_token, offline, headers);
768+ },
769+ item
770+ )
771+ );
751772 }
752773
753774 // Wait for all downloads to complete
@@ -760,17 +781,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
760781 return true ;
761782}
762783
763- bool common_download_model (
764- const common_params_model & model ,
765- const std::string & bearer_token ,
766- bool offline ) {
784+ bool common_download_model (const common_params_model & model,
785+ const std::string & bearer_token ,
786+ bool offline ,
787+ const std::vector<std::pair<std::string, std::string>> & headers ) {
767788 // Basic validation of the model.url
768789 if (model.url .empty ()) {
769790 LOG_ERR (" %s: invalid model url\n " , __func__);
770791 return false ;
771792 }
772793
773- if (!common_download_file_single (model.url , model.path , bearer_token, offline)) {
794+ if (!common_download_file_single (model.url , model.path , bearer_token, offline, headers )) {
774795 return false ;
775796 }
776797
@@ -829,7 +850,7 @@ bool common_download_model(
829850 }
830851
831852 // Download in parallel
832- common_download_file_multiple (urls, bearer_token, offline);
853+ common_download_file_multiple (urls, bearer_token, offline, headers );
833854 }
834855
835856 return true ;
@@ -1023,7 +1044,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
10231044 std::string local_path = fs_get_cache_file (model_filename);
10241045
10251046 const std::string blob_url = url_prefix + " /blobs/" + gguf_digest;
1026- if (!common_download_file_single (blob_url, local_path, token, false )) {
1047+ if (!common_download_file_single (blob_url, local_path, token, false , {} )) {
10271048 throw std::runtime_error (" Failed to download Docker Model" );
10281049 }
10291050
@@ -1041,7 +1062,10 @@ common_hf_file_res common_get_hf_file(const std::string &, const std::string &,
10411062 throw std::runtime_error (" download functionality is not enabled in this build" );
10421063}
10431064
1044- bool common_download_model (const common_params_model &, const std::string &, bool ) {
1065+ bool common_download_model (const common_params_model &,
1066+ const std::string &,
1067+ bool ,
1068+ const std::vector<std::pair<std::string, std::string>> &) {
10451069 throw std::runtime_error (" download functionality is not enabled in this build" );
10461070}
10471071
0 commit comments