Skip to content

Commit a19202a

Browse files
ochafikochafik
andcommitted
refactor as suggested
Co-Authored-By: ochafik <[email protected]>
1 parent 7f7d859 commit a19202a

File tree

1 file changed

+146
-143
lines changed

1 file changed

+146
-143
lines changed

common/arg.cpp

Lines changed: 146 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -283,182 +283,185 @@ static bool common_download_file_single(const std::string & url, const std::stri
283283
std::string last_modified;
284284
};
285285

286+
if (offline) {
287+
if (file_exists) {
288+
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
289+
return true; // skip verification/downloading
290+
}
291+
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
292+
return false;
293+
}
294+
286295
common_load_model_from_url_headers headers;
287296
bool head_request_ok = false;
297+
bool should_download = !file_exists; // by default, we should download if the file does not exist
288298

289-
if (!file_exists || !offline)
290-
{
291-
bool should_download = !file_exists; // by default, we should download if the file does not exist
292-
293-
// Initialize libcurl
294-
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
295-
curl_slist_ptr http_headers;
296-
if (!curl) {
297-
LOG_ERR("%s: error initializing libcurl\n", __func__);
298-
return false;
299-
}
299+
// Initialize libcurl
300+
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
301+
curl_slist_ptr http_headers;
302+
if (!curl) {
303+
LOG_ERR("%s: error initializing libcurl\n", __func__);
304+
return false;
305+
}
300306

301-
// Set the URL, allow to follow http redirection
302-
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
303-
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
307+
// Set the URL, allow to follow http redirection
308+
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
309+
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
304310

305-
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
306-
// Check if hf-token or bearer-token was specified
307-
if (!bearer_token.empty()) {
308-
std::string auth_header = "Authorization: Bearer " + bearer_token;
309-
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
310-
}
311-
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
311+
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
312+
// Check if hf-token or bearer-token was specified
313+
if (!bearer_token.empty()) {
314+
std::string auth_header = "Authorization: Bearer " + bearer_token;
315+
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
316+
}
317+
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
312318

313319
#if defined(_WIN32)
314-
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
315-
// operating system. Currently implemented under MS-Windows.
316-
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
320+
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
321+
// operating system. Currently implemented under MS-Windows.
322+
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
317323
#endif
318324

319-
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
320-
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
321-
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
325+
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
326+
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
327+
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
322328

323-
static std::regex header_regex("([^:]+): (.*)\r\n");
324-
static std::regex etag_regex("ETag", std::regex_constants::icase);
325-
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
329+
static std::regex header_regex("([^:]+): (.*)\r\n");
330+
static std::regex etag_regex("ETag", std::regex_constants::icase);
331+
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
326332

327-
std::string header(buffer, n_items);
328-
std::smatch match;
329-
if (std::regex_match(header, match, header_regex)) {
330-
const std::string & key = match[1];
331-
const std::string & value = match[2];
332-
if (std::regex_match(key, match, etag_regex)) {
333-
headers->etag = value;
334-
} else if (std::regex_match(key, match, last_modified_regex)) {
335-
headers->last_modified = value;
336-
}
333+
std::string header(buffer, n_items);
334+
std::smatch match;
335+
if (std::regex_match(header, match, header_regex)) {
336+
const std::string & key = match[1];
337+
const std::string & value = match[2];
338+
if (std::regex_match(key, match, etag_regex)) {
339+
headers->etag = value;
340+
} else if (std::regex_match(key, match, last_modified_regex)) {
341+
headers->last_modified = value;
337342
}
338-
return n_items;
339-
};
340-
341-
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
342-
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
343-
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
344-
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
345-
346-
// we only allow retrying once for HEAD requests
347-
// this is for the use case of using running offline (no internet), retrying can be annoying
348-
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
349-
if (!was_perform_successful) {
350-
head_request_ok = false;
351343
}
344+
return n_items;
345+
};
352346

353-
long http_code = 0;
354-
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
355-
if (http_code == 200) {
356-
head_request_ok = true;
357-
} else {
358-
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
359-
head_request_ok = false;
360-
}
361-
362-
// if head_request_ok is false, we don't have the etag or last-modified headers
363-
// we leave should_download as-is, which is true if the file does not exist
364-
if (head_request_ok) {
365-
// check if ETag or Last-Modified headers are different
366-
// if it is, we need to download the file again
367-
if (!etag.empty() && etag != headers.etag) {
368-
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
369-
should_download = true;
370-
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
371-
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
372-
should_download = true;
373-
}
374-
}
375-
376-
if (should_download) {
377-
std::string path_temporary = path + ".downloadInProgress";
378-
if (file_exists) {
379-
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
380-
if (remove(path.c_str()) != 0) {
381-
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
382-
return false;
383-
}
384-
}
347+
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
348+
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
349+
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
350+
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
385351

386-
// Set the output file
352+
// we only allow retrying once for HEAD requests
353+
// this is for the use case of using running offline (no internet), retrying can be annoying
354+
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
355+
if (!was_perform_successful) {
356+
head_request_ok = false;
357+
}
387358

388-
struct FILE_deleter {
389-
void operator()(FILE * f) const {
390-
fclose(f);
391-
}
392-
};
359+
long http_code = 0;
360+
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
361+
if (http_code == 200) {
362+
head_request_ok = true;
363+
} else {
364+
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
365+
head_request_ok = false;
366+
}
393367

394-
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
395-
if (!outfile) {
396-
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
368+
// if head_request_ok is false, we don't have the etag or last-modified headers
369+
// we leave should_download as-is, which is true if the file does not exist
370+
if (head_request_ok) {
371+
// check if ETag or Last-Modified headers are different
372+
// if it is, we need to download the file again
373+
if (!etag.empty() && etag != headers.etag) {
374+
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
375+
should_download = true;
376+
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
377+
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
378+
should_download = true;
379+
}
380+
}
381+
382+
if (should_download) {
383+
std::string path_temporary = path + ".downloadInProgress";
384+
if (file_exists) {
385+
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
386+
if (remove(path.c_str()) != 0) {
387+
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
397388
return false;
398389
}
390+
}
399391

400-
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
401-
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
402-
return fwrite(data, size, nmemb, (FILE *)fd);
403-
};
404-
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
405-
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
406-
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
392+
// Set the output file
407393

408-
// display download progress
409-
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
394+
struct FILE_deleter {
395+
void operator()(FILE * f) const {
396+
fclose(f);
397+
}
398+
};
410399

411-
// helper function to hide password in URL
412-
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
413-
std::size_t protocol_pos = url.find("://");
414-
if (protocol_pos == std::string::npos) {
415-
return url; // Malformed URL
416-
}
400+
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
401+
if (!outfile) {
402+
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
403+
return false;
404+
}
417405

418-
std::size_t at_pos = url.find('@', protocol_pos + 3);
419-
if (at_pos == std::string::npos) {
420-
return url; // No password in URL
421-
}
406+
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
407+
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
408+
return fwrite(data, size, nmemb, (FILE *)fd);
409+
};
410+
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
411+
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
412+
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
422413

423-
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
424-
};
414+
// display download progress
415+
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
425416

426-
// start the download
427-
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
428-
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
429-
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
430-
if (!was_perform_successful) {
431-
return false;
417+
// helper function to hide password in URL
418+
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
419+
std::size_t protocol_pos = url.find("://");
420+
if (protocol_pos == std::string::npos) {
421+
return url; // Malformed URL
432422
}
433423

434-
long http_code = 0;
435-
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
436-
if (http_code < 200 || http_code >= 400) {
437-
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
438-
return false;
424+
std::size_t at_pos = url.find('@', protocol_pos + 3);
425+
if (at_pos == std::string::npos) {
426+
return url; // No password in URL
439427
}
440428

441-
// Causes file to be closed explicitly here before we rename it.
442-
outfile.reset();
429+
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
430+
};
443431

444-
// Write the updated JSON metadata file.
445-
metadata.update({
446-
{"url", url},
447-
{"etag", headers.etag},
448-
{"lastModified", headers.last_modified}
449-
});
450-
write_file(metadata_path, metadata.dump(4));
451-
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
432+
// start the download
433+
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
434+
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
435+
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
436+
if (!was_perform_successful) {
437+
return false;
438+
}
452439

453-
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
454-
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
455-
return false;
456-
}
457-
} else {
458-
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
440+
long http_code = 0;
441+
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
442+
if (http_code < 200 || http_code >= 400) {
443+
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
444+
return false;
445+
}
446+
447+
// Causes file to be closed explicitly here before we rename it.
448+
outfile.reset();
449+
450+
// Write the updated JSON metadata file.
451+
metadata.update({
452+
{"url", url},
453+
{"etag", headers.etag},
454+
{"lastModified", headers.last_modified}
455+
});
456+
write_file(metadata_path, metadata.dump(4));
457+
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
458+
459+
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
460+
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
461+
return false;
459462
}
460463
} else {
461-
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
464+
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
462465
}
463466

464467
return true;

0 commit comments

Comments
 (0)