Skip to content

Commit da6dd67

Browse files
committed
refactor model endpoint
1 parent 4c2abb3 commit da6dd67

File tree

4 files changed

+16
-19
lines changed

4 files changed

+16
-19
lines changed

common/arg.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
228228
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
229229
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
230230

231+
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
231232
// Check if hf-token or bearer-token was specified
232233
if (!bearer_token.empty()) {
233234
std::string auth_header = "Authorization: Bearer " + bearer_token;
@@ -374,7 +375,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
374375

375376
// display download progress
376377
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
377-
curl_easy_setopt(curl.get(), CURLOPT_USERAGENT, "llama.cpp/1.0");
378378

379379

380380
// helper function to hide password in URL
@@ -547,12 +547,7 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
547547
curl_slist_ptr http_headers;
548548
std::string res_str;
549549

550-
std::string model_endpoint = "https://huggingface.co/";
551-
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
552-
if (model_endpoint_env) {
553-
model_endpoint = model_endpoint_env;
554-
if (model_endpoint.back() != '/') model_endpoint += '/';
555-
}
550+
std::string model_endpoint = get_model_endpoint();
556551

557552
std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag;
558553
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
@@ -669,12 +664,7 @@ static void common_params_handle_model(
669664
}
670665
}
671666

672-
std::string model_endpoint = "https://huggingface.co/";
673-
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
674-
if (model_endpoint_env) {
675-
model_endpoint = model_endpoint_env;
676-
if (model_endpoint.back() != '/') model_endpoint += '/';
677-
}
667+
std::string model_endpoint = get_model_endpoint();
678668
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
679669
// make sure model path is present (for caching purposes)
680670
if (model.path.empty()) {

common/common.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,16 @@ struct common_init_result common_init_from_params(common_params & params) {
10271027
return iparams;
10281028
}
10291029

1030+
std::string get_model_endpoint() {
1031+
std::string model_endpoint = "https://huggingface.co/";
1032+
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
1033+
if (model_endpoint_env) {
1034+
model_endpoint = model_endpoint_env;
1035+
if (model_endpoint.back() != '/') model_endpoint += '/';
1036+
}
1037+
return model_endpoint;
1038+
}
1039+
10301040
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
10311041
llama_clear_adapter_lora(ctx);
10321042
for (auto & la : lora) {

common/common.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,8 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
543543
// clear LoRA adapters from context, then apply new list of adapters
544544
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
545545

546+
std::string get_model_endpoint();
547+
546548
//
547549
// Batch utils
548550
//

examples/run/run.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,12 +697,7 @@ class LlamaData {
697697
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
698698
std::string url;
699699

700-
std::string model_endpoint = "https://huggingface.co/";
701-
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
702-
if (model_endpoint_env) {
703-
model_endpoint = model_endpoint_env;
704-
if (model_endpoint.back() != '/') model_endpoint += '/';
705-
}
700+
std::string model_endpoint = get_model_endpoint();
706701

707702
if (pos == std::string::npos) {
708703
auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/");

0 commit comments

Comments
 (0)