diff --git a/common/arg.cpp b/common/arg.cpp
index 0f01bb31454a4..23f649f21bc48 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -22,6 +22,7 @@
 #include <algorithm>
 #include <climits>
 #include <cstdarg>
+#include <cstring>
 #include <filesystem>
 #include <fstream>
 #include <regex>
@@ -37,6 +38,11 @@
 #include <curl/curl.h>
 #include <curl/easy.h>
 #include <future>
+#ifdef _WIN32
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
 #endif
 
 using json = nlohmann::ordered_json;
@@ -220,19 +226,123 @@ struct curl_slist_ptr {
 #define CURL_MAX_RETRY 3
 #define CURL_RETRY_DELAY_SECONDS 2
 
+// Header callback state for detecting resume support
+struct curl_resume_state {
+    bool       attempting_resume;     // true while a ranged (resumed) GET is in flight
+    curl_off_t resume_offset;         // byte offset the download was resumed from
+    FILE *     file_ptr;              // output file (append mode) of the current download
+    CURL *     curl_handle;           // handle this state belongs to
+    bool       server_supports_resume;
+    bool       decided;               // only process once per response
+};
+
+static size_t curl_resume_header_callback(char * buffer, size_t size, size_t n_items, void * userdata) {
+    size_t n_bytes = size * n_items;
+    curl_resume_state * state = static_cast<curl_resume_state *>(userdata);
+
+    if (!state || !state->attempting_resume || state->decided) {
+        return n_bytes; // Not resuming or already processed
+    }
+
+    // libcurl does NOT null-terminate the header buffer, so copy it before
+    // applying any C string function to it
+    const std::string header(buffer, n_bytes);
+
+    // Parse status line (e.g., "HTTP/1.1 200 OK\r\n" or "HTTP/2 200\r\n")
+    if (n_bytes > 9 && header.compare(0, 5, "HTTP/") == 0) {
+        // Find first space and parse status code from there
+        const size_t space = header.find(' ');
+        if (space != std::string::npos) {
+            int status_code = atoi(header.c_str() + space + 1);
+            if (status_code == 200) {
+                // Server ignored our range request - need to start fresh
+                LOG_WRN("curl: server returned 200 instead of 206 - doesn't support resume\n");
+                state->server_supports_resume = false;
+                state->decided = true;
+
+                // Truncate file to 0 and rewind immediately
+                if (state->file_ptr) {
+                    fflush(state->file_ptr);
+#ifdef _WIN32
+                    _chsize_s(_fileno(state->file_ptr), 0);
+#else
+                    if (ftruncate(fileno(state->file_ptr), 0) != 0) {
+                        LOG_WRN("curl: failed to truncate file\n");
+                    }
+#endif
+                    rewind(state->file_ptr);
+                }
+            } else if (status_code == 206) {
+                // Partial content - resume is working
+                state->server_supports_resume = true;
+                state->decided = true;
+            }
+        }
+    }
+
+    // Also check for Content-Range header as confirmation of resume support
+    // (case-insensitive compare done by hand: strncasecmp is POSIX-only and
+    // this file also builds on MSVC)
+    if (n_bytes > 14) {
+        std::string prefix = header.substr(0, 14);
+        std::transform(prefix.begin(), prefix.end(), prefix.begin(), ::tolower);
+        if (prefix == "content-range:") {
+            state->server_supports_resume = true;
+            state->decided = true;
+        }
+    }
+
+    return n_bytes;
+}
+
 static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) {
     int remaining_attempts = max_attempts;
 
+    // If this is a download operation, CURLOPT_PRIVATE holds the output FILE *
+    // (set by common_download_file_single); HEAD and other requests leave it unset
+    void * pv = nullptr;
+    curl_easy_getinfo(curl, CURLINFO_PRIVATE, &pv);
+    FILE * write_file = static_cast<FILE *>(pv);
+
     while (remaining_attempts > 0) {
+        // Initialize state for this attempt
+        curl_resume_state state = {false, 0, write_file, curl, false, false};
+
+        // For resume support on GET requests with file output after a failed attempt
+        if (write_file && strcmp(method_name, "GET") == 0 && max_attempts - remaining_attempts > 0) {
+            // Flush any pending data and get the current file position.
+            // Plain ftell() returns long (32-bit on Windows) and would overflow for
+            // files > 2 GiB, so use the 64-bit-offset variants.
+            fflush(write_file);
+#ifdef _WIN32
+            curl_off_t file_size = _ftelli64(write_file);
+#else
+            curl_off_t file_size = ftello(write_file);
+#endif
+            if (file_size > 0) {
+                // Use CURLOPT_RESUME_FROM_LARGE for proper resume
+                curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, file_size);
+                LOG_INF("%s: resuming download from byte %lld\n", __func__, (long long) file_size);
+
+                // Set up header callback to detect if server supports resume
+                state.attempting_resume = true;
+                state.resume_offset     = file_size;
+                curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, curl_resume_header_callback);
+                curl_easy_setopt(curl, CURLOPT_HEADERDATA, &state);
+            } else {
+                // curl options persist on the handle across curl_easy_perform() calls:
+                // clear any resume offset left over from a previous attempt
+                // (e.g. after the callback truncated the file)
+                curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, (curl_off_t) 0);
+            }
+        }
+
         LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
 
         CURLcode res = curl_easy_perform(curl);
+
+        // Restore original header callback if we changed it: `state` is per-attempt,
+        // so the pointer handed to CURLOPT_HEADERDATA must not outlive this iteration
+        if (state.attempting_resume) {
+            curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, nullptr);
+            curl_easy_setopt(curl, CURLOPT_HEADERDATA, nullptr);
+        }
+
         if (res == CURLE_OK) {
             return true;
         }
 
-        int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
-        LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
+        // Check for specific error conditions
+        long http_code = 0;
+        curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
+
+        // HTTP 416 Range Not Satisfiable means the file is already complete
+        if (http_code == 416) {
+            LOG_INF("%s: file already complete (HTTP 416)\n", __func__);
+            return true;
+        }
+
+        // Fixed exponential backoff: multiply delay by 2 for each retry
+        // (the old formula was pow(delay, n), i.e. delay^n, not delay * 2^n)
+        int exponential_backoff_delay = retry_delay_seconds * std::pow(2, max_attempts - remaining_attempts) * 1000;
+        LOG_WRN("%s: curl_easy_perform() failed: %s (HTTP %ld), retrying after %d milliseconds...\n",
+                __func__, curl_easy_strerror(res), http_code, exponential_backoff_delay);
 
         remaining_attempts--;
         if (remaining_attempts == 0) break;
@@ -249,6 +359,16 @@ static bool common_download_file_single(const std::string & url, const std::stri
     // Check if the file already exists locally
     auto file_exists = std::filesystem::exists(path);
 
+    // Send a HEAD request to retrieve the etag and last-modified headers
+    struct common_load_model_from_url_headers {
+        std::string etag;
+        std::string last_modified;
+        std::string x_linked_etag;      // SHA256 hash from HuggingFace
+        curl_off_t  content_length = 0; // Total file size from Content-Length or x-linked-size
+    };
+
+    common_load_model_from_url_headers headers;
+
     // If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json"; nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead @@ -272,6 +392,14 @@ static bool common_download_file_single(const std::string & url, const std::stri if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { last_modified = metadata.at("lastModified"); } + // Load the expected file size if available + if (metadata.contains("contentLength") && metadata.at("contentLength").is_number()) { + headers.content_length = metadata.at("contentLength"); + } + // Load the SHA256 hash if available + if (metadata.contains("sha256") && metadata.at("sha256").is_string()) { + headers.x_linked_etag = metadata.at("sha256"); + } } catch (const nlohmann::json::exception & e) { LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); } @@ -285,13 +413,6 @@ static bool common_download_file_single(const std::string & url, const std::stri LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; - }; - - common_load_model_from_url_headers headers; bool head_request_ok = false; bool should_download = !file_exists; // by default, we should download if the file does not exist @@ -328,6 +449,9 @@ static bool common_download_file_single(const std::string & url, const std::stri static std::regex header_regex("([^:]+): (.*)\r\n"); static std::regex etag_regex("ETag", std::regex_constants::icase); static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); + static std::regex content_length_regex("Content-Length", std::regex_constants::icase); + static std::regex x_linked_size_regex("x-linked-size", std::regex_constants::icase); + static std::regex x_linked_etag_regex("x-linked-etag", std::regex_constants::icase); std::string header(buffer, n_items); 
std::smatch match; @@ -338,6 +462,19 @@ static bool common_download_file_single(const std::string & url, const std::stri headers->etag = value; } else if (std::regex_match(key, match, last_modified_regex)) { headers->last_modified = value; + } else if (std::regex_match(key, match, content_length_regex)) { + headers->content_length = std::stoll(value); + } else if (std::regex_match(key, match, x_linked_size_regex)) { + // HuggingFace provides file size in x-linked-size header + headers->content_length = std::stoll(value); + } else if (std::regex_match(key, match, x_linked_etag_regex)) { + // HuggingFace provides SHA256 hash in x-linked-etag header + // Remove quotes if present + std::string hash = value; + if (hash.size() >= 2 && hash.front() == '"' && hash.back() == '"') { + hash = hash.substr(1, hash.size() - 2); + } + headers->x_linked_etag = hash; } } return n_items; @@ -396,7 +533,15 @@ static bool common_download_file_single(const std::string & url, const std::stri } }; - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); + // Check if partial file exists from previous attempt + bool resume_download = false; + curl_off_t resume_from = 0; + if (std::filesystem::exists(path_temporary)) { + resume_from = std::filesystem::file_size(path_temporary); + resume_download = resume_from > 0; + } + + std::unique_ptr outfile(fopen(path_temporary.c_str(), resume_download ? 
"ab" : "wb")); if (!outfile) { LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; @@ -410,6 +555,15 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); + // Set resume position if we have a partial file + if (resume_download) { + curl_easy_setopt(curl.get(), CURLOPT_RESUME_FROM_LARGE, resume_from); + LOG_INF("%s: resuming download from byte %lld\n", __func__, (long long)resume_from); + } + + // Store file pointer for retry mechanism + curl_easy_setopt(curl.get(), CURLOPT_PRIVATE, outfile.get()); + // display download progress curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); @@ -438,7 +592,11 @@ static bool common_download_file_single(const std::string & url, const std::stri long http_code = 0; curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { + // HTTP 416 means the file is already complete (e.g., when resuming a completed download) + if (http_code == 416) { + LOG_INF("%s: file already complete (HTTP 416), treating as success\n", __func__); + // File is already complete, we can proceed + } else if (http_code < 200 || http_code >= 400) { LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); return false; } @@ -446,11 +604,65 @@ static bool common_download_file_single(const std::string & url, const std::stri // Causes file to be closed explicitly here before we rename it. 
outfile.reset(); + // Verify file size if we know the expected size + if (headers.content_length > 0) { + curl_off_t actual_size = std::filesystem::file_size(path_temporary); + if (actual_size != headers.content_length) { + LOG_ERR("%s: file size mismatch: expected %lld, got %lld\n", __func__, + (long long)headers.content_length, (long long)actual_size); + return false; + } + LOG_INF("%s: file size verified: %lld bytes\n", __func__, (long long)actual_size); + } + + // Verify SHA256 if we have it (from HuggingFace x-linked-etag) + if (!headers.x_linked_etag.empty()) { + LOG_INF("%s: verifying SHA256 hash...\n", __func__); + + // Use sha256sum command to compute hash + std::string cmd = "sha256sum \"" + path_temporary + "\" 2>/dev/null"; + FILE * pipe = popen(cmd.c_str(), "r"); + if (pipe) { + char buffer[130]; // SHA256 is 64 chars + filename + space + std::string result; + if (fgets(buffer, sizeof(buffer), pipe) != nullptr) { + result = buffer; + } + pclose(pipe); + + // Extract just the hash (first 64 characters) + if (result.size() >= 64) { + std::string computed_hash = result.substr(0, 64); + + // Compare with expected hash (case-insensitive) + std::string expected = headers.x_linked_etag; + std::transform(expected.begin(), expected.end(), expected.begin(), ::tolower); + std::transform(computed_hash.begin(), computed_hash.end(), computed_hash.begin(), ::tolower); + + if (computed_hash != expected) { + LOG_ERR("%s: SHA256 hash mismatch!\n", __func__); + LOG_ERR("%s: expected: %s\n", __func__, expected.c_str()); + LOG_ERR("%s: computed: %s\n", __func__, computed_hash.c_str()); + LOG_ERR("%s: file may be corrupted, deleting: %s\n", __func__, path_temporary.c_str()); + std::filesystem::remove(path_temporary); + return false; + } + LOG_INF("%s: SHA256 hash verified: %s\n", __func__, computed_hash.c_str()); + } else { + LOG_WRN("%s: sha256sum output format unexpected, skipping verification\n", __func__); + } + } else { + LOG_WRN("%s: sha256sum command not available, 
skipping hash verification\n", __func__); + } + } + // Write the updated JSON metadata file. metadata.update({ {"url", url}, {"etag", headers.etag}, - {"lastModified", headers.last_modified} + {"lastModified", headers.last_modified}, + {"contentLength", headers.content_length}, + {"sha256", headers.x_linked_etag} // SHA256 hash if available (from HuggingFace) }); write_file(metadata_path, metadata.dump(4)); LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9658abf969dd2..6e5dcab465b71 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -145,6 +145,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) llama_build_and_test(test-chat.cpp) + llama_build_and_test(test-download-resume.cpp) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) diff --git a/tests/test-download-resume.cpp b/tests/test-download-resume.cpp new file mode 100644 index 0000000000000..ae8e9562fa68c --- /dev/null +++ b/tests/test-download-resume.cpp @@ -0,0 +1,170 @@ +// Test curl download resume functionality +#include "arg.h" +#include "common.h" + +#include +#include +#include +#include +#include + +#undef NDEBUG +#include + +// Mock server that simulates partial downloads +class MockDownloadServer { + public: + MockDownloadServer(const std::string & test_file_path, size_t total_size) : + file_path(test_file_path), + file_size(total_size) { + // Create a test file with predictable content + std::ofstream f(file_path, std::ios::binary); + for (size_t i = 0; i < file_size; i++) { + char c = 'A' + (i % 26); + f.write(&c, 1); + } + } + + ~MockDownloadServer() { + // Cleanup test file + if (std::filesystem::exists(file_path)) { + 
std::filesystem::remove(file_path); + } + } + + bool simulate_partial_download(const std::string & dest_path, size_t bytes_to_transfer, size_t start_offset = 0) { + std::ifstream src(file_path, std::ios::binary); + std::ofstream dst(dest_path, std::ios::binary | (start_offset > 0 ? std::ios::app : std::ios::trunc)); + + if (!src || !dst) { + return false; + } + + src.seekg(start_offset); + + char buffer[1024]; + size_t transferred = 0; + while (transferred < bytes_to_transfer && src.good()) { + size_t to_read = std::min(sizeof(buffer), bytes_to_transfer - transferred); + src.read(buffer, to_read); + size_t read = src.gcount(); + if (read > 0) { + dst.write(buffer, read); + transferred += read; + } else { + break; + } + } + + return transferred == bytes_to_transfer; + } + + bool verify_downloaded_content(const std::string & downloaded_path) { + std::ifstream original(file_path, std::ios::binary); + std::ifstream downloaded(downloaded_path, std::ios::binary); + + if (!original || !downloaded) { + return false; + } + + // Compare file sizes first + original.seekg(0, std::ios::end); + downloaded.seekg(0, std::ios::end); + if (original.tellg() != downloaded.tellg()) { + return false; + } + + // Reset to beginning + original.seekg(0); + downloaded.seekg(0); + + // Compare content + char c1, c2; + while (original.get(c1) && downloaded.get(c2)) { + if (c1 != c2) { + return false; + } + } + + return true; + } + + private: + std::string file_path; + size_t file_size; +}; + +static void test_resume_download() { + printf("Testing download resume functionality...\n"); + + const std::string test_source = "test_source.bin"; + const std::string test_dest = "test_download.bin.downloadInProgress"; + const size_t file_size = 10000; // 10KB test file + + // Create mock server with test file + MockDownloadServer server(test_source, file_size); + + // Test 1: Simulate interrupted download at 30% + printf(" Test 1: Interrupt at 30%%... 
"); + size_t first_chunk = file_size * 0.3; + assert(server.simulate_partial_download(test_dest, first_chunk)); + assert(std::filesystem::file_size(test_dest) == first_chunk); + printf("OK\n"); + + // Test 2: Resume download from 30% to 70% + printf(" Test 2: Resume to 70%%... "); + size_t second_chunk = file_size * 0.4; + assert(server.simulate_partial_download(test_dest, second_chunk, first_chunk)); + assert(std::filesystem::file_size(test_dest) == first_chunk + second_chunk); + printf("OK\n"); + + // Test 3: Complete the download + printf(" Test 3: Complete download... "); + size_t final_chunk = file_size - (first_chunk + second_chunk); + assert(server.simulate_partial_download(test_dest, final_chunk, first_chunk + second_chunk)); + assert(std::filesystem::file_size(test_dest) == file_size); + printf("OK\n"); + + // Test 4: Verify content integrity + printf(" Test 4: Verify integrity... "); + assert(server.verify_downloaded_content(test_dest)); + printf("OK\n"); + + // Cleanup + if (std::filesystem::exists(test_dest)) { + std::filesystem::remove(test_dest); + } + + printf("All download resume tests passed!\n"); +} + +static void test_exponential_backoff() { + printf("Testing exponential backoff calculation...\n"); + + int base_delay = 2; // 2 seconds base + + // Test the corrected exponential backoff formula + for (int attempt = 0; attempt < 3; attempt++) { + int expected = base_delay * (1 << attempt) * 1000; // 2^attempt * base * 1000ms + printf(" Attempt %d: Expected delay = %d ms\n", attempt + 1, expected); + + // These should match our fixed implementation: + // Attempt 1: 2 * 2^0 * 1000 = 2000ms + // Attempt 2: 2 * 2^1 * 1000 = 4000ms + // Attempt 3: 2 * 2^2 * 1000 = 8000ms + assert((attempt == 0 && expected == 2000) || (attempt == 1 && expected == 4000) || + (attempt == 2 && expected == 8000)); + } + + printf("Exponential backoff tests passed!\n"); +} + +int main() { + printf("test-download-resume: Testing curl download resume functionality\n\n"); + + 
+    test_resume_download();
+    test_exponential_backoff();
+
+    printf("\nAll tests passed successfully!\n");
+    return 0;
+}