Skip to content

Commit dbc12c4

Browse files
committed
common : support custom HTTP headers for model downloads
Signed-off-by: Adrien Gallouët <[email protected]>
1 parent 28175f8 commit dbc12c4

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

common/download.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl,
303303
// download one single file from remote URL to local path
304304
static bool common_download_file_single_online(const std::string & url,
305305
const std::string & path,
306-
const std::string & bearer_token) {
306+
const std::string & bearer_token,
307+
const std::vector<std::pair<std::string, std::string>> & headers) {
307308
static const int max_attempts = 3;
308309
static const int retry_delay_seconds = 2;
309310
for (int i = 0; i < max_attempts; ++i) {
@@ -322,10 +323,14 @@ static bool common_download_file_single_online(const std::string & url,
322323

323324
// Initialize libcurl
324325
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
325-
common_load_model_from_url_headers headers;
326-
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
326+
common_load_model_from_url_headers response_headers;
327+
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &response_headers);
327328
curl_slist_ptr http_headers;
328-
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
329+
for (const auto & h : headers) {
330+
auto header_str = h.first + ": " + h.second;
331+
http_headers.ptr = curl_slist_append(http_headers.ptr, header_str.c_str());
332+
}
333+
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
329334
if (!was_perform_successful) {
330335
head_request_ok = false;
331336
}
@@ -345,15 +350,15 @@ static bool common_download_file_single_online(const std::string & url,
345350
if (head_request_ok) {
346351
// check if ETag or Last-Modified headers are different
347352
// if it is, we need to download the file again
348-
if (!etag.empty() && etag != headers.etag) {
353+
if (!etag.empty() && etag != response_headers.etag) {
349354
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
350-
headers.etag.c_str());
355+
response_headers.etag.c_str());
351356
should_download = true;
352357
should_download_from_scratch = true;
353358
}
354359
}
355360

356-
const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
361+
const bool accept_ranges_supported = !response_headers.accept_ranges.empty() && response_headers.accept_ranges != "none";
357362
if (should_download) {
358363
if (file_exists &&
359364
!accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
@@ -381,13 +386,13 @@ static bool common_download_file_single_online(const std::string & url,
381386
}
382387
}
383388
if (head_request_ok) {
384-
write_etag(path, headers.etag);
389+
write_etag(path, response_headers.etag);
385390
}
386391

387392
// start the download
388393
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
389394
__func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
390-
headers.etag.c_str(), headers.last_modified.c_str());
395+
response_headers.etag.c_str(), response_headers.last_modified.c_str());
391396
const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
392397
if (!was_pull_successful) {
393398
if (i + 1 < max_attempts) {
@@ -433,7 +438,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
433438
curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L);
434439
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
435440
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
436-
auto data_vec = static_cast<std::vector<char> *>(data);
441+
auto *data_vec = static_cast<std::vector<char> *>(data);
437442
data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
438443
return size * nmemb;
439444
};
@@ -565,7 +570,8 @@ static bool common_pull_file(httplib::Client & cli,
565570
// download one single file from remote URL to local path
566571
static bool common_download_file_single_online(const std::string & url,
567572
const std::string & path,
568-
const std::string & bearer_token) {
573+
const std::string & bearer_token,
574+
const std::vector<std::pair<std::string, std::string>> & headers) {
569575
static const int max_attempts = 3;
570576
static const int retry_delay_seconds = 2;
571577

@@ -575,6 +581,9 @@ static bool common_download_file_single_online(const std::string & url,
575581
if (!bearer_token.empty()) {
576582
default_headers.insert({"Authorization", "Bearer " + bearer_token});
577583
}
584+
for (const auto & h : headers) {
585+
default_headers.insert({h.first, h.second});
586+
}
578587
cli.set_default_headers(default_headers);
579588

580589
const bool file_exists = std::filesystem::exists(path);
@@ -718,9 +727,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
718727
static bool common_download_file_single(const std::string & url,
719728
const std::string & path,
720729
const std::string & bearer_token,
721-
bool offline) {
730+
bool offline,
731+
const std::vector<std::pair<std::string, std::string>> & headers) {
722732
if (!offline) {
723-
return common_download_file_single_online(url, path, bearer_token);
733+
return common_download_file_single_online(url, path, bearer_token, headers);
724734
}
725735

726736
if (!std::filesystem::exists(path)) {
@@ -734,13 +744,24 @@ static bool common_download_file_single(const std::string & url,
734744

735745
// download multiple files from remote URLs to local paths
736746
// the input is a vector of pairs <url, path>
737-
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
747+
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
748+
const std::string & bearer_token,
749+
bool offline,
750+
const std::vector<std::pair<std::string, std::string>> & headers) {
738751
// Prepare download in parallel
739752
std::vector<std::future<bool>> futures_download;
753+
futures_download.reserve(urls.size());
754+
740755
for (auto const & item : urls) {
741-
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
742-
return common_download_file_single(it.first, it.second, bearer_token, offline);
743-
}, item));
756+
futures_download.push_back(
757+
std::async(
758+
std::launch::async,
759+
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
760+
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
761+
},
762+
item
763+
)
764+
);
744765
}
745766

746767
// Wait for all downloads to complete
@@ -753,17 +774,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
753774
return true;
754775
}
755776

756-
bool common_download_model(
757-
const common_params_model & model,
758-
const std::string & bearer_token,
759-
bool offline) {
777+
bool common_download_model(const common_params_model & model,
778+
const std::string & bearer_token,
779+
bool offline,
780+
const std::vector<std::pair<std::string, std::string>> & headers) {
760781
// Basic validation of the model.url
761782
if (model.url.empty()) {
762783
LOG_ERR("%s: invalid model url\n", __func__);
763784
return false;
764785
}
765786

766-
if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
787+
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
767788
return false;
768789
}
769790

@@ -822,7 +843,7 @@ bool common_download_model(
822843
}
823844

824845
// Download in parallel
825-
common_download_file_multiple(urls, bearer_token, offline);
846+
common_download_file_multiple(urls, bearer_token, offline, headers);
826847
}
827848

828849
return true;
@@ -1016,7 +1037,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
10161037
std::string local_path = fs_get_cache_file(model_filename);
10171038

10181039
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
1019-
if (!common_download_file_single(blob_url, local_path, token, false)) {
1040+
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
10201041
throw std::runtime_error("Failed to download Docker Model");
10211042
}
10221043

@@ -1034,7 +1055,10 @@ common_hf_file_res common_get_hf_file(const std::string &, const std::string &,
10341055
throw std::runtime_error("download functionality is not enabled in this build");
10351056
}
10361057

1037-
bool common_download_model(const common_params_model &, const std::string &, bool) {
1058+
bool common_download_model(const common_params_model &,
1059+
const std::string &,
1060+
bool,
1061+
const std::vector<std::pair<std::string, std::string>> &) {
10381062
throw std::runtime_error("download functionality is not enabled in this build");
10391063
}
10401064

common/download.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <string>
4+
#include <vector>
45

56
struct common_params_model;
67

@@ -45,7 +46,8 @@ common_hf_file_res common_get_hf_file(
4546
bool common_download_model(
4647
const common_params_model & model,
4748
const std::string & bearer_token,
48-
bool offline);
49+
bool offline,
50+
const std::vector<std::pair<std::string, std::string>> & headers = {});
4951

5052
// returns list of cached models
5153
std::vector<common_cached_model_info> common_list_cached_models();

0 commit comments

Comments
 (0)