Skip to content

Commit e6c4319

Browse files
committed
common : add common_remote_get_content
1 parent d5fe4e8 commit e6c4319

File tree

2 files changed

+55
-33
lines changed

2 files changed

+55
-33
lines changed

common/arg.cpp

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -527,52 +527,30 @@ static bool common_download_model(
527527
return true;
528528
}
529529

530-
/**
531-
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
532-
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
533-
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
534-
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
535-
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
536-
*
537-
* Return pair of <repo, file> (with "repo" already having tag removed)
538-
*
539-
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
540-
*/
541-
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
542-
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
543-
std::string tag = parts.size() > 1 ? parts.back() : "latest";
544-
std::string hf_repo = parts[0];
545-
if (string_split<std::string>(hf_repo, '/').size() != 2) {
546-
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
547-
}
548-
549-
// fetch model info from Hugging Face Hub API
530+
// get remote file content, returns <http_code, content>
531+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers) {
550532
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
551533
curl_slist_ptr http_headers;
552-
std::string res_str;
553-
554-
std::string model_endpoint = get_model_endpoint();
534+
std::vector<char> res_buffer;
555535

556-
std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
557536
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
558537
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
559538
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
560539
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
561-
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
540+
auto data_vec = static_cast<std::vector<char> *>(data);
541+
data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb);
562542
return size * nmemb;
563543
};
564544
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
565-
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
545+
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer);
566546
#if defined(_WIN32)
567547
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
568548
#endif
569-
if (!bearer_token.empty()) {
570-
std::string auth_header = "Authorization: Bearer " + bearer_token;
571-
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
572-
}
573549
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
574550
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
575-
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
551+
for (const auto & header : headers) {
552+
http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
553+
}
576554
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
577555

578556
CURLcode res = curl_easy_perform(curl.get());
@@ -582,9 +560,46 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
582560
}
583561

584562
long res_code;
585-
std::string ggufFile = "";
586-
std::string mmprojFile = "";
587563
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
564+
565+
return { res_code, res_buffer };
566+
}
567+
568+
/**
569+
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
570+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
571+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
572+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
573+
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
574+
*
575+
* Return pair of <repo, file> (with "repo" already having tag removed)
576+
*
577+
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
578+
*/
579+
static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) {
580+
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
581+
std::string tag = parts.size() > 1 ? parts.back() : "latest";
582+
std::string hf_repo = parts[0];
583+
if (string_split<std::string>(hf_repo, '/').size() != 2) {
584+
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
585+
}
586+
587+
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
588+
589+
// headers
590+
std::vector<std::string> headers;
591+
headers.push_back("Accept: application/json");
592+
if (!bearer_token.empty()) {
593+
headers.push_back("Authorization: Bearer " + bearer_token);
594+
}
595+
596+
// make the request
597+
auto res = common_remote_get_content(url, headers);
598+
long res_code = res.first;
599+
std::string res_str(res.second.data(), res.second.size());
600+
std::string ggufFile;
601+
std::string mmprojFile;
602+
588603
if (res_code == 200) {
589604
// extract ggufFile.rfilename in json, using regex
590605
{
@@ -640,6 +655,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s
640655
return {};
641656
}
642657

658+
/**
 * Stub used when the build has no CURL support (the non-LLAMA_USE_CURL branch).
 *
 * Remote content cannot be fetched in this configuration, so the call always
 * throws std::runtime_error and never returns a <http_code, content> pair.
 *
 * The parameters are intentionally unnamed, matching the common_get_hf_file stub
 * in this branch, so that the unused arguments do not produce compiler warnings.
 */
std::pair<long, std::vector<char>> common_remote_get_content(const std::string &, const std::vector<std::string> &) {
    throw std::runtime_error("error: built without CURL, cannot download model from the internet");
}
661+
643662
#endif // LLAMA_USE_CURL
644663

645664
//

common/arg.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
7878

7979
// function to be used by test-arg-parser
8080
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
81+
82+
// get remote file content, returns <http_code, content>
83+
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const std::vector<std::string> & headers);

0 commit comments

Comments
 (0)