@@ -296,172 +296,174 @@ static bool common_download_file_single(const std::string & url, const std::stri
296296 bool head_request_ok = false ;
297297 bool should_download = !file_exists; // by default, we should download if the file does not exist
298298
299- // Initialize libcurl
300- curl_ptr curl (curl_easy_init (), &curl_easy_cleanup);
301- curl_slist_ptr http_headers;
302- if (!curl) {
303- LOG_ERR (" %s: error initializing libcurl\n " , __func__);
304- return false ;
305- }
299+ {
300+ // Initialize libcurl
301+ curl_ptr curl (curl_easy_init (), &curl_easy_cleanup);
302+ curl_slist_ptr http_headers;
303+ if (!curl) {
304+ LOG_ERR (" %s: error initializing libcurl\n " , __func__);
305+ return false ;
306+ }
306307
307- // Set the URL, allow to follow http redirection
308- curl_easy_setopt (curl.get (), CURLOPT_URL, url.c_str ());
309- curl_easy_setopt (curl.get (), CURLOPT_FOLLOWLOCATION, 1L );
308+ // Set the URL, allow to follow http redirection
309+ curl_easy_setopt (curl.get (), CURLOPT_URL, url.c_str ());
310+ curl_easy_setopt (curl.get (), CURLOPT_FOLLOWLOCATION, 1L );
310311
311- http_headers.ptr = curl_slist_append (http_headers.ptr , " User-Agent: llama-cpp" );
312- // Check if hf-token or bearer-token was specified
313- if (!bearer_token.empty ()) {
314- std::string auth_header = " Authorization: Bearer " + bearer_token;
315- http_headers.ptr = curl_slist_append (http_headers.ptr , auth_header.c_str ());
316- }
317- curl_easy_setopt (curl.get (), CURLOPT_HTTPHEADER, http_headers.ptr );
312+ http_headers.ptr = curl_slist_append (http_headers.ptr , " User-Agent: llama-cpp" );
313+ // Check if hf-token or bearer-token was specified
314+ if (!bearer_token.empty ()) {
315+ std::string auth_header = " Authorization: Bearer " + bearer_token;
316+ http_headers.ptr = curl_slist_append (http_headers.ptr , auth_header.c_str ());
317+ }
318+ curl_easy_setopt (curl.get (), CURLOPT_HTTPHEADER, http_headers.ptr );
318319
319320#if defined(_WIN32)
320- // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
321- // operating system. Currently implemented under MS-Windows.
322- curl_easy_setopt (curl.get (), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
321+ // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
322+ // operating system. Currently implemented under MS-Windows.
323+ curl_easy_setopt (curl.get (), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
323324#endif
324325
325- typedef size_t (*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t , size_t , void *);
326- auto header_callback = [](char * buffer, size_t /* size*/ , size_t n_items, void * userdata) -> size_t {
327- common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
326+ typedef size_t (*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t , size_t , void *);
327+ auto header_callback = [](char * buffer, size_t /* size*/ , size_t n_items, void * userdata) -> size_t {
328+ common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
328329
329- static std::regex header_regex (" ([^:]+): (.*)\r\n " );
330- static std::regex etag_regex (" ETag" , std::regex_constants::icase);
331- static std::regex last_modified_regex (" Last-Modified" , std::regex_constants::icase);
330+ static std::regex header_regex (" ([^:]+): (.*)\r\n " );
331+ static std::regex etag_regex (" ETag" , std::regex_constants::icase);
332+ static std::regex last_modified_regex (" Last-Modified" , std::regex_constants::icase);
332333
333- std::string header (buffer, n_items);
334- std::smatch match;
335- if (std::regex_match (header, match, header_regex)) {
336- const std::string & key = match[1 ];
337- const std::string & value = match[2 ];
338- if (std::regex_match (key, match, etag_regex)) {
339- headers->etag = value;
340- } else if (std::regex_match (key, match, last_modified_regex)) {
341- headers->last_modified = value;
334+ std::string header (buffer, n_items);
335+ std::smatch match;
336+ if (std::regex_match (header, match, header_regex)) {
337+ const std::string & key = match[1 ];
338+ const std::string & value = match[2 ];
339+ if (std::regex_match (key, match, etag_regex)) {
340+ headers->etag = value;
341+ } else if (std::regex_match (key, match, last_modified_regex)) {
342+ headers->last_modified = value;
343+ }
342344 }
343- }
344- return n_items;
345- };
345+ return n_items;
346+ };
346347
347- curl_easy_setopt (curl.get (), CURLOPT_NOBODY, 1L ); // will trigger the HEAD verb
348- curl_easy_setopt (curl.get (), CURLOPT_NOPROGRESS, 1L ); // hide head request progress
349- curl_easy_setopt (curl.get (), CURLOPT_HEADERFUNCTION, static_cast <CURLOPT_HEADERFUNCTION_PTR>(header_callback));
350- curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &headers);
348+ curl_easy_setopt (curl.get (), CURLOPT_NOBODY, 1L ); // will trigger the HEAD verb
349+ curl_easy_setopt (curl.get (), CURLOPT_NOPROGRESS, 1L ); // hide head request progress
350+ curl_easy_setopt (curl.get (), CURLOPT_HEADERFUNCTION, static_cast <CURLOPT_HEADERFUNCTION_PTR>(header_callback));
351+ curl_easy_setopt (curl.get (), CURLOPT_HEADERDATA, &headers);
351352
352- // we only allow retrying once for HEAD requests
353- // this is for the use case of using running offline (no internet), retrying can be annoying
354- bool was_perform_successful = curl_perform_with_retry (url, curl.get (), 1 , 0 , " HEAD" );
355- if (!was_perform_successful) {
356- head_request_ok = false ;
357- }
353+ // we only allow retrying once for HEAD requests
354+ // this is for the use case of using running offline (no internet), retrying can be annoying
355+ bool was_perform_successful = curl_perform_with_retry (url, curl.get (), 1 , 0 , " HEAD" );
356+ if (!was_perform_successful) {
357+ head_request_ok = false ;
358+ }
358359
359- long http_code = 0 ;
360- curl_easy_getinfo (curl.get (), CURLINFO_RESPONSE_CODE, &http_code);
361- if (http_code == 200 ) {
362- head_request_ok = true ;
363- } else {
364- LOG_WRN (" %s: HEAD invalid http status code received: %ld\n " , __func__, http_code);
365- head_request_ok = false ;
366- }
360+ long http_code = 0 ;
361+ curl_easy_getinfo (curl.get (), CURLINFO_RESPONSE_CODE, &http_code);
362+ if (http_code == 200 ) {
363+ head_request_ok = true ;
364+ } else {
365+ LOG_WRN (" %s: HEAD invalid http status code received: %ld\n " , __func__, http_code);
366+ head_request_ok = false ;
367+ }
368+
369+ // if head_request_ok is false, we don't have the etag or last-modified headers
370+ // we leave should_download as-is, which is true if the file does not exist
371+ if (head_request_ok) {
372+ // check if ETag or Last-Modified headers are different
373+ // if it is, we need to download the file again
374+ if (!etag.empty () && etag != headers.etag ) {
375+ LOG_WRN (" %s: ETag header is different (%s != %s): triggering a new download\n " , __func__, etag.c_str (), headers.etag .c_str ());
376+ should_download = true ;
377+ } else if (!last_modified.empty () && last_modified != headers.last_modified ) {
378+ LOG_WRN (" %s: Last-Modified header is different (%s != %s): triggering a new download\n " , __func__, last_modified.c_str (), headers.last_modified .c_str ());
379+ should_download = true ;
380+ }
381+ }
382+
383+ if (should_download) {
384+ std::string path_temporary = path + " .downloadInProgress" ;
385+ if (file_exists) {
386+ LOG_WRN (" %s: deleting previous downloaded file: %s\n " , __func__, path.c_str ());
387+ if (remove (path.c_str ()) != 0 ) {
388+ LOG_ERR (" %s: unable to delete file: %s\n " , __func__, path.c_str ());
389+ return false ;
390+ }
391+ }
367392
368- // if head_request_ok is false, we don't have the etag or last-modified headers
369- // we leave should_download as-is, which is true if the file does not exist
370- if (head_request_ok) {
371- // check if ETag or Last-Modified headers are different
372- // if it is, we need to download the file again
373- if (!etag.empty () && etag != headers.etag ) {
374- LOG_WRN (" %s: ETag header is different (%s != %s): triggering a new download\n " , __func__, etag.c_str (), headers.etag .c_str ());
375- should_download = true ;
376- } else if (!last_modified.empty () && last_modified != headers.last_modified ) {
377- LOG_WRN (" %s: Last-Modified header is different (%s != %s): triggering a new download\n " , __func__, last_modified.c_str (), headers.last_modified .c_str ());
378- should_download = true ;
379- }
380- }
393+ // Set the output file
381394
382- if (should_download) {
383- std::string path_temporary = path + " .downloadInProgress" ;
384- if (file_exists) {
385- LOG_WRN (" %s: deleting previous downloaded file: %s\n " , __func__, path.c_str ());
386- if (remove (path.c_str ()) != 0 ) {
387- LOG_ERR (" %s: unable to delete file: %s\n " , __func__, path.c_str ());
395+ struct FILE_deleter {
396+ void operator ()(FILE * f) const {
397+ fclose (f);
398+ }
399+ };
400+
401+ std::unique_ptr<FILE, FILE_deleter> outfile (fopen (path_temporary.c_str (), " wb" ));
402+ if (!outfile) {
403+ LOG_ERR (" %s: error opening local file for writing: %s\n " , __func__, path.c_str ());
388404 return false ;
389405 }
390- }
391406
392- // Set the output file
407+ typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
408+ auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
409+ return fwrite (data, size, nmemb, (FILE *)fd);
410+ };
411+ curl_easy_setopt (curl.get (), CURLOPT_NOBODY, 0L );
412+ curl_easy_setopt (curl.get (), CURLOPT_WRITEFUNCTION, static_cast <CURLOPT_WRITEFUNCTION_PTR>(write_callback));
413+ curl_easy_setopt (curl.get (), CURLOPT_WRITEDATA, outfile.get ());
393414
394- struct FILE_deleter {
395- void operator ()(FILE * f) const {
396- fclose (f);
397- }
398- };
415+ // display download progress
416+ curl_easy_setopt (curl.get (), CURLOPT_NOPROGRESS, 0L );
399417
400- std::unique_ptr<FILE, FILE_deleter> outfile (fopen (path_temporary.c_str (), " wb" ));
401- if (!outfile) {
402- LOG_ERR (" %s: error opening local file for writing: %s\n " , __func__, path.c_str ());
403- return false ;
404- }
418+ // helper function to hide password in URL
419+ auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
420+ std::size_t protocol_pos = url.find (" ://" );
421+ if (protocol_pos == std::string::npos) {
422+ return url; // Malformed URL
423+ }
405424
406- typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
407- auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
408- return fwrite (data, size, nmemb, (FILE *)fd);
409- };
410- curl_easy_setopt (curl.get (), CURLOPT_NOBODY, 0L );
411- curl_easy_setopt (curl.get (), CURLOPT_WRITEFUNCTION, static_cast <CURLOPT_WRITEFUNCTION_PTR>(write_callback));
412- curl_easy_setopt (curl.get (), CURLOPT_WRITEDATA, outfile.get ());
425+ std::size_t at_pos = url.find (' @' , protocol_pos + 3 );
426+ if (at_pos == std::string::npos) {
427+ return url; // No password in URL
428+ }
413429
414- // display download progress
415- curl_easy_setopt (curl. get (), CURLOPT_NOPROGRESS, 0L ) ;
430+ return url. substr ( 0 , protocol_pos + 3 ) + " ******** " + url. substr (at_pos);
431+ } ;
416432
417- // helper function to hide password in URL
418- auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
419- std::size_t protocol_pos = url.find (" ://" );
420- if (protocol_pos == std::string::npos) {
421- return url; // Malformed URL
433+ // start the download
434+ LOG_INF (" %s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n " , __func__,
435+ llama_download_hide_password_in_url (url).c_str (), path.c_str (), headers.etag .c_str (), headers.last_modified .c_str ());
436+ bool was_perform_successful = curl_perform_with_retry (url, curl.get (), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, " GET" );
437+ if (!was_perform_successful) {
438+ return false ;
422439 }
423440
424- std::size_t at_pos = url.find (' @' , protocol_pos + 3 );
425- if (at_pos == std::string::npos) {
426- return url; // No password in URL
441+ long http_code = 0 ;
442+ curl_easy_getinfo (curl.get (), CURLINFO_RESPONSE_CODE, &http_code);
443+ if (http_code < 200 || http_code >= 400 ) {
444+ LOG_ERR (" %s: invalid http status code received: %ld\n " , __func__, http_code);
445+ return false ;
427446 }
428447
429- return url. substr ( 0 , protocol_pos + 3 ) + " ******** " + url. substr (at_pos);
430- } ;
448+ // Causes file to be closed explicitly here before we rename it.
449+ outfile. reset () ;
431450
432- // start the download
433- LOG_INF (" %s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n " , __func__,
434- llama_download_hide_password_in_url (url).c_str (), path.c_str (), headers.etag .c_str (), headers.last_modified .c_str ());
435- bool was_perform_successful = curl_perform_with_retry (url, curl.get (), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, " GET" );
436- if (!was_perform_successful) {
437- return false ;
438- }
439-
440- long http_code = 0 ;
441- curl_easy_getinfo (curl.get (), CURLINFO_RESPONSE_CODE, &http_code);
442- if (http_code < 200 || http_code >= 400 ) {
443- LOG_ERR (" %s: invalid http status code received: %ld\n " , __func__, http_code);
444- return false ;
445- }
451+ // Write the updated JSON metadata file.
452+ metadata.update ({
453+ {" url" , url},
454+ {" etag" , headers.etag },
455+ {" lastModified" , headers.last_modified }
456+ });
457+ write_file (metadata_path, metadata.dump (4 ));
458+ LOG_DBG (" %s: file metadata saved: %s\n " , __func__, metadata_path.c_str ());
446459
447- // Causes file to be closed explicitly here before we rename it.
448- outfile.reset ();
449-
450- // Write the updated JSON metadata file.
451- metadata.update ({
452- {" url" , url},
453- {" etag" , headers.etag },
454- {" lastModified" , headers.last_modified }
455- });
456- write_file (metadata_path, metadata.dump (4 ));
457- LOG_DBG (" %s: file metadata saved: %s\n " , __func__, metadata_path.c_str ());
458-
459- if (rename (path_temporary.c_str (), path.c_str ()) != 0 ) {
460- LOG_ERR (" %s: unable to rename file: %s to %s\n " , __func__, path_temporary.c_str (), path.c_str ());
461- return false ;
460+ if (rename (path_temporary.c_str (), path.c_str ()) != 0 ) {
461+ LOG_ERR (" %s: unable to rename file: %s to %s\n " , __func__, path_temporary.c_str (), path.c_str ());
462+ return false ;
463+ }
464+ } else {
465+ LOG_INF (" %s: using cached file: %s\n " , __func__, path.c_str ());
462466 }
463- } else {
464- LOG_INF (" %s: using cached file: %s\n " , __func__, path.c_str ());
465467 }
466468
467469 return true ;
0 commit comments