Skip to content

Commit 142fd71

Browse files
committed
Merge remote-tracking branch 'origin/common-support-custom-http-headers-for-model-downloads' into installama
2 parents 08a0180 + dbc12c4 commit 142fd71

File tree

2 files changed

+52
-26
lines changed

2 files changed

+52
-26
lines changed

common/download.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl,
303303
// download one single file from remote URL to local path
304304
static bool common_download_file_single_online(const std::string & url,
305305
const std::string & path,
306-
const std::string & bearer_token) {
306+
const std::string & bearer_token,
307+
const std::vector<std::pair<std::string, std::string>> & headers) {
307308
static const int max_attempts = 3;
308309
static const int retry_delay_seconds = 2;
309310
for (int i = 0; i < max_attempts; ++i) {
@@ -322,10 +323,14 @@ static bool common_download_file_single_online(const std::string & url,
322323

323324
// Initialize libcurl
324325
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
325-
common_load_model_from_url_headers headers;
326-
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
326+
common_load_model_from_url_headers response_headers;
327+
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &response_headers);
327328
curl_slist_ptr http_headers;
328-
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
329+
for (const auto & h : headers) {
330+
auto header_str = h.first + ": " + h.second;
331+
http_headers.ptr = curl_slist_append(http_headers.ptr, header_str.c_str());
332+
}
333+
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
329334
if (!was_perform_successful) {
330335
head_request_ok = false;
331336
}
@@ -345,15 +350,15 @@ static bool common_download_file_single_online(const std::string & url,
345350
if (head_request_ok) {
346351
// check if ETag or Last-Modified headers are different
347352
// if it is, we need to download the file again
348-
if (!etag.empty() && etag != headers.etag) {
353+
if (!etag.empty() && etag != response_headers.etag) {
349354
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
350-
headers.etag.c_str());
355+
response_headers.etag.c_str());
351356
should_download = true;
352357
should_download_from_scratch = true;
353358
}
354359
}
355360

356-
const bool accept_ranges_supported = !headers.accept_ranges.empty() && headers.accept_ranges != "none";
361+
const bool accept_ranges_supported = !response_headers.accept_ranges.empty() && response_headers.accept_ranges != "none";
357362
if (should_download) {
358363
if (file_exists &&
359364
!accept_ranges_supported) { // Resumable downloads not supported, delete and start again.
@@ -381,13 +386,13 @@ static bool common_download_file_single_online(const std::string & url,
381386
}
382387
}
383388
if (head_request_ok) {
384-
write_etag(path, headers.etag);
389+
write_etag(path, response_headers.etag);
385390
}
386391

387392
// start the download
388393
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
389394
__func__, llama_download_hide_password_in_url(url).c_str(), path_temporary.c_str(),
390-
headers.etag.c_str(), headers.last_modified.c_str());
395+
response_headers.etag.c_str(), response_headers.last_modified.c_str());
391396
const bool was_pull_successful = common_pull_file(curl.get(), path_temporary);
392397
if (!was_pull_successful) {
393398
if (i + 1 < max_attempts) {
@@ -433,7 +438,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
433438
curl_easy_setopt(curl.get(), CURLOPT_VERBOSE, 1L);
434439
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
435440
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
436-
auto data_vec = static_cast<std::vector<char> *>(data);
441+
auto *data_vec = static_cast<std::vector<char> *>(data);
437442
data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
438443
return size * nmemb;
439444
};
@@ -572,7 +577,8 @@ static bool common_pull_file(httplib::Client & cli,
572577
// download one single file from remote URL to local path
573578
static bool common_download_file_single_online(const std::string & url,
574579
const std::string & path,
575-
const std::string & bearer_token) {
580+
const std::string & bearer_token,
581+
const std::vector<std::pair<std::string, std::string>> & headers) {
576582
static const int max_attempts = 3;
577583
static const int retry_delay_seconds = 2;
578584

@@ -582,6 +588,9 @@ static bool common_download_file_single_online(const std::string & url,
582588
if (!bearer_token.empty()) {
583589
default_headers.insert({"Authorization", "Bearer " + bearer_token});
584590
}
591+
for (const auto & h : headers) {
592+
default_headers.insert({h.first, h.second});
593+
}
585594
cli.set_default_headers(default_headers);
586595

587596
const bool file_exists = std::filesystem::exists(path);
@@ -725,9 +734,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
725734
static bool common_download_file_single(const std::string & url,
726735
const std::string & path,
727736
const std::string & bearer_token,
728-
bool offline) {
737+
bool offline,
738+
const std::vector<std::pair<std::string, std::string>> & headers) {
729739
if (!offline) {
730-
return common_download_file_single_online(url, path, bearer_token);
740+
return common_download_file_single_online(url, path, bearer_token, headers);
731741
}
732742

733743
if (!std::filesystem::exists(path)) {
@@ -741,13 +751,24 @@ static bool common_download_file_single(const std::string & url,
741751

742752
// download multiple files from remote URLs to local paths
743753
// the input is a vector of pairs <url, path>
744-
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
754+
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
755+
const std::string & bearer_token,
756+
bool offline,
757+
const std::vector<std::pair<std::string, std::string>> & headers) {
745758
// Prepare download in parallel
746759
std::vector<std::future<bool>> futures_download;
760+
futures_download.reserve(urls.size());
761+
747762
for (auto const & item : urls) {
748-
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
749-
return common_download_file_single(it.first, it.second, bearer_token, offline);
750-
}, item));
763+
futures_download.push_back(
764+
std::async(
765+
std::launch::async,
766+
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
767+
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
768+
},
769+
item
770+
)
771+
);
751772
}
752773

753774
// Wait for all downloads to complete
@@ -760,17 +781,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
760781
return true;
761782
}
762783

763-
bool common_download_model(
764-
const common_params_model & model,
765-
const std::string & bearer_token,
766-
bool offline) {
784+
bool common_download_model(const common_params_model & model,
785+
const std::string & bearer_token,
786+
bool offline,
787+
const std::vector<std::pair<std::string, std::string>> & headers) {
767788
// Basic validation of the model.url
768789
if (model.url.empty()) {
769790
LOG_ERR("%s: invalid model url\n", __func__);
770791
return false;
771792
}
772793

773-
if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
794+
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
774795
return false;
775796
}
776797

@@ -829,7 +850,7 @@ bool common_download_model(
829850
}
830851

831852
// Download in parallel
832-
common_download_file_multiple(urls, bearer_token, offline);
853+
common_download_file_multiple(urls, bearer_token, offline, headers);
833854
}
834855

835856
return true;
@@ -1023,7 +1044,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
10231044
std::string local_path = fs_get_cache_file(model_filename);
10241045

10251046
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
1026-
if (!common_download_file_single(blob_url, local_path, token, false)) {
1047+
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
10271048
throw std::runtime_error("Failed to download Docker Model");
10281049
}
10291050

@@ -1041,7 +1062,10 @@ common_hf_file_res common_get_hf_file(const std::string &, const std::string &,
10411062
throw std::runtime_error("download functionality is not enabled in this build");
10421063
}
10431064

1044-
bool common_download_model(const common_params_model &, const std::string &, bool) {
1065+
bool common_download_model(const common_params_model &,
1066+
const std::string &,
1067+
bool,
1068+
const std::vector<std::pair<std::string, std::string>> &) {
10451069
throw std::runtime_error("download functionality is not enabled in this build");
10461070
}
10471071

common/download.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <string>
4+
#include <vector>
45

56
struct common_params_model;
67

@@ -45,7 +46,8 @@ common_hf_file_res common_get_hf_file(
4546
bool common_download_model(
4647
const common_params_model & model,
4748
const std::string & bearer_token,
48-
bool offline);
49+
bool offline,
50+
const std::vector<std::pair<std::string, std::string>> & headers = {});
4951

5052
// returns list of cached models
5153
std::vector<common_cached_model_info> common_list_cached_models();

0 commit comments

Comments
 (0)