@@ -190,6 +190,12 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //

+void gpt_params_handle_hf_token(gpt_params & params) {
+    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
+        params.hf_token = std::getenv("HF_TOKEN");
+    }
+}
+
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model
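The added helper gives an explicit `--hf-token` value precedence over the environment: `HF_TOKEN` is only consulted while the field is still empty. A minimal standalone sketch of the same fallback logic (`resolve_hf_token` is a hypothetical name, not part of the patch):

    #include <cstdlib>
    #include <string>

    // Hypothetical helper mirroring the fallback above: an explicit CLI value
    // wins; otherwise HF_TOKEN from the environment is used, if set.
    static std::string resolve_hf_token(const std::string & cli_value) {
        if (!cli_value.empty()) {
            return cli_value;                       // --hf-token was given
        }
        const char * env = std::getenv("HF_TOKEN"); // null when unset
        return env ? std::string(env) : std::string();
    }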
@@ -237,6 +243,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {

     gpt_params_handle_model_default(params);

+    gpt_params_handle_hf_token(params);
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
@@ -652,6 +660,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.model_url = argv[i];
         return true;
     }
+    if (arg == "-hft" || arg == "--hf-token") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.hf_token = argv[i];
+        return true;
+    }
     if (arg == "-hfr" || arg == "--hf-repo") {
         CHECK_ARG
         params.hf_repo = argv[i];
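Note that the new `--hf-token` branch open-codes the bounds check that the neighbouring options delegate to the `CHECK_ARG` macro. Judging from the `-hfr` branch right below, that macro presumably expands to the same three statements, so the two forms should behave identically. An assumed definition, inferred from the open-coded check (the actual one lives elsewhere in common.cpp):

    // Assumption: likely shape of CHECK_ARG, matching the manual check above.
    #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }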
@@ -1576,6 +1592,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" });
     options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" });
     options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" });
+    options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" });

     options.push_back({ "retrieval" });
     options.push_back({ "retrieval", "--context-file FNAME", "file to load context from (repeat to specify multiple files)" });
@@ -2015,9 +2032,9 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     llama_model * model = nullptr;

     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
-        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else if (!params.model_url.empty()) {
-        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams);
+        model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
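With the extra parameter, callers thread the token through unchanged; an empty string means no Authorization header is sent. A call-site sketch of the widened signature (repo, file, and path values are placeholders, and `params` stands for the surrounding gpt_params instance):

    // Hypothetical call site for the five-argument llama_load_model_from_hf.
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_hf(
        "some-org/some-repo",       // hf_repo (placeholder)
        "model-q4_0.gguf",          // hf_file (placeholder)
        "models/model-q4_0.gguf",   // local path to store the download
        params.hf_token.c_str(),    // empty string disables auth
        mparams);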
@@ -2205,7 +2222,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) {
     return str.rfind(prefix, 0) == 0;
 }

-static bool llama_download_file(const std::string & url, const std::string & path) {
+static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {

     // Initialize libcurl
     std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
@@ -2220,6 +2237,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
     curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
     curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);

+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+        std::string auth_header = "Authorization: Bearer ";
+        auth_header += hf_token.c_str();
+        struct curl_slist * http_headers = NULL;
+        http_headers = curl_slist_append(http_headers, auth_header.c_str());
+        curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
 #if defined(_WIN32)
     // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
     // operating system. Currently implemented under MS-Windows.
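The header list built here is handed to libcurl but, as written, never released; CURLOPT_HTTPHEADER does not take ownership of the slist. A self-contained sketch of the same setup that also frees the list once the transfer is done (standard libcurl API, error handling elided; `fetch_with_bearer` is a hypothetical name):

    #include <curl/curl.h>
    #include <string>

    // Sketch: attach "Authorization: Bearer <token>" to a curl handle and
    // release the header list after the transfer completes.
    static bool fetch_with_bearer(const std::string & url, const std::string & token) {
        CURL * curl = curl_easy_init();
        if (!curl) return false;
        curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
        struct curl_slist * headers = NULL;
        if (!token.empty()) {
            std::string auth_header = "Authorization: Bearer " + token;
            headers = curl_slist_append(headers, auth_header.c_str());
            curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
        }
        CURLcode res = curl_easy_perform(curl);
        curl_slist_free_all(headers);  // no-op when headers is NULL
        curl_easy_cleanup(curl);
        return res == CURLE_OK;
    }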
@@ -2415,14 +2441,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat
 struct llama_model * llama_load_model_from_url(
         const char * model_url,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // Basic validation of the model_url
     if (!model_url || strlen(model_url) == 0) {
         fprintf(stderr, "%s: invalid model_url\n", __func__);
         return NULL;
     }

-    if (!llama_download_file(model_url, path_model)) {
+    if (!llama_download_file(model_url, path_model, hf_token)) {
         return NULL;
     }
@@ -2470,14 +2497,14 @@ struct llama_model * llama_load_model_from_url(
     // Prepare download in parallel
     std::vector<std::future<bool>> futures_download;
     for (int idx = 1; idx < n_split; idx++) {
-        futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool {
+        futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
             char split_path[PATH_MAX] = {0};
             llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);

             char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
             llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);

-            return llama_download_file(split_url, split_path);
+            return llama_download_file(split_url, split_path, hf_token);
         }, idx));
     }
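Unlike the other captures, `hf_token` enters the lambda by value: it is a raw `const char *` parameter, so copying the pointer is cheap, and the pointed-to string only needs to outlive the futures, which holds because the function joins them before returning. A minimal standalone illustration of the fan-out pattern (placeholder work stands in for the real download):

    #include <cstdio>
    #include <future>
    #include <vector>

    int main() {
        const char * token = "hf_placeholder";  // stand-in for hf_token
        std::vector<std::future<bool>> futures;
        for (int idx = 1; idx < 4; idx++) {
            // token is captured by value: the pointer is copied into the task.
            futures.push_back(std::async(std::launch::async, [token](int i) -> bool {
                std::printf("split %d, token %s\n", i, token);
                return true;
            }, idx));
        }
        bool ok = true;
        for (auto & f : futures) {
            if (!f.get()) ok = false;  // join every task before token's owner dies
        }
        return ok ? 0 : 1;
    }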
@@ -2496,6 +2523,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * repo,
         const char * model,
         const char * path_model,
+        const char * hf_token,
         const struct llama_model_params & params) {
     // construct hugging face model url:
     //
@@ -2511,14 +2539,15 @@ struct llama_model * llama_load_model_from_hf(
     model_url += "/resolve/main/";
     model_url += model;

-    return llama_load_model_from_url(model_url.c_str(), path_model, params);
+    return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params);
 }

 #else

 struct llama_model * llama_load_model_from_url(
         const char * /*model_url*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
     return nullptr;
@@ -2528,6 +2557,7 @@ struct llama_model * llama_load_model_from_hf(
         const char * /*repo*/,
         const char * /*model*/,
         const char * /*path_model*/,
+        const char * /*hf_token*/,
         const struct llama_model_params & /*params*/) {
     fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
     return nullptr;