diff --git a/CODEOWNERS b/CODEOWNERS index 6a6468fc27d95..89b84ce850640 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -14,6 +14,7 @@ /common/build-info.* @ggerganov /common/common.* @ggerganov /common/console.* @ggerganov +/common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov /common/sampling.* @ggerganov diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 8ab3d445104a7..fe290bf8fdda4 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC common.h console.cpp console.h + http.h json-partial.cpp json-partial.h json-schema-to-grammar.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 8da74f909764b..cbca8b5ac5abb 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -32,13 +32,11 @@ #include #include -//#define LLAMA_USE_CURL - #if defined(LLAMA_USE_CURL) #include #include #else -#include +#include "http.h" #endif #ifdef __linux__ @@ -596,77 +594,6 @@ std::pair> common_remote_get_content(const std::string & #else -struct common_url { - std::string scheme; - std::string user; - std::string password; - std::string host; - std::string path; -}; - -static common_url parse_url(const std::string & url) { - common_url parts; - auto scheme_end = url.find("://"); - - if (scheme_end == std::string::npos) { - throw std::runtime_error("invalid URL: no scheme"); - } - parts.scheme = url.substr(0, scheme_end); - - if (parts.scheme != "http" && parts.scheme != "https") { - throw std::runtime_error("unsupported URL scheme: " + parts.scheme); - } - - auto rest = url.substr(scheme_end + 3); - auto at_pos = rest.find('@'); - - if (at_pos != std::string::npos) { - auto auth = rest.substr(0, at_pos); - auto colon_pos = auth.find(':'); - if (colon_pos != std::string::npos) { - parts.user = auth.substr(0, colon_pos); - parts.password = auth.substr(colon_pos + 1); - } else { - parts.user = auth; - } - rest = rest.substr(at_pos + 1); - } - - auto slash_pos = rest.find('/'); - - if (slash_pos != std::string::npos) { - parts.host = rest.substr(0, slash_pos); - parts.path = rest.substr(slash_pos); - } else { - parts.host = rest; - parts.path = "/"; - } - return parts; -} - -static std::pair http_client(const std::string & url) { - common_url parts = parse_url(url); - - if (parts.host.empty()) { - throw std::runtime_error("error: invalid URL format"); - } - - if (!parts.user.empty()) { - throw std::runtime_error("error: user:password@ not supported yet"); // TODO - } - - httplib::Client cli(parts.scheme + "://" + parts.host); - cli.set_follow_location(true); - - // TODO cert - - return { std::move(cli), std::move(parts) }; -} - -static std::string show_masked_url(const common_url & parts) { - return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path; -} - static void print_progress(size_t current, size_t total) { if (!is_output_a_tty()) { return; @@ -759,7 +686,7 @@ static bool common_download_file_single_online(const std::string & url, static const int max_attempts = 3; static const int retry_delay_seconds = 2; - auto [cli, parts] = http_client(url); + auto [cli, parts] = common_http_client(url); httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; if (!bearer_token.empty()) { @@ -839,7 +766,7 @@ static bool common_download_file_single_online(const std::string & url, // start the download LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", - __func__, show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); + __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); if (!was_pull_successful) { if (i + 1 < max_attempts) { @@ -867,7 +794,7 @@ static bool common_download_file_single_online(const std::string & url, std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { - auto [cli, parts] = http_client(url); + auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; for (const auto & header : params.headers) { diff --git a/common/http.h b/common/http.h new file mode 100644 index 0000000000000..8e29787dcc6f7 --- /dev/null +++ b/common/http.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +struct common_http_url { + std::string scheme; + std::string user; + std::string password; + std::string host; + std::string path; +}; + +static common_http_url common_http_parse_url(const std::string & url) { + common_http_url parts; + auto scheme_end = url.find("://"); + + if (scheme_end == std::string::npos) { + throw std::runtime_error("invalid URL: no scheme"); + } + parts.scheme = url.substr(0, scheme_end); + + if (parts.scheme != "http" && parts.scheme != "https") { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + + auto rest = url.substr(scheme_end + 3); + auto at_pos = rest.find('@'); + + if (at_pos != std::string::npos) { + auto auth = rest.substr(0, at_pos); + auto colon_pos = auth.find(':'); + if (colon_pos != std::string::npos) { + parts.user = auth.substr(0, colon_pos); + parts.password = auth.substr(colon_pos + 1); + } else { + parts.user = auth; + } + rest = rest.substr(at_pos + 1); + } + + auto slash_pos = rest.find('/'); + + if (slash_pos != std::string::npos) { + parts.host = rest.substr(0, slash_pos); + parts.path = rest.substr(slash_pos); + } else { + parts.host = rest; + parts.path = "/"; + } + return parts; +} + +static std::pair common_http_client(const std::string & url) { + common_http_url parts = common_http_parse_url(url); + + if (parts.host.empty()) { + throw std::runtime_error("error: invalid URL format"); + } + + httplib::Client cli(parts.scheme + "://" + parts.host); + + if (!parts.user.empty()) { + cli.set_basic_auth(parts.user, parts.password); + } + + cli.set_follow_location(true); + + return { std::move(cli), std::move(parts) }; +} + +static std::string common_http_show_masked_url(const common_http_url & parts) { + return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path; +} diff --git a/tools/run/run.cpp b/tools/run/run.cpp index 772d66c921caf..b90a7253c4346 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -9,6 +9,7 @@ #include #if defined(_WIN32) +# define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX # define NOMINMAX # endif @@ -22,6 +23,8 @@ #if defined(LLAMA_USE_CURL) # include +#else +# include "http.h" #endif #include @@ -397,7 +400,6 @@ class File { # endif }; -#ifdef LLAMA_USE_CURL class HttpClient { public: int init(const std::string & url, const std::vector & headers, const std::string & output_file, @@ -428,6 +430,8 @@ class HttpClient { return 0; } +#ifdef LLAMA_USE_CURL + ~HttpClient() { if (chunk) { curl_slist_free_all(chunk); @@ -532,6 +536,117 @@ class HttpClient { return curl_easy_perform(curl); } +#else // LLAMA_USE_CURL is not defined + +#define curl_off_t long long // temporary hack + + private: + // this is a direct translation of the cURL download() above + int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, + const bool progress, std::string * response_str = nullptr) { + try { + auto [cli, url_parts] = common_http_client(url); + + httplib::Headers headers; + for (const auto & h : headers_vec) { + size_t pos = h.find(':'); + if (pos != std::string::npos) { + headers.emplace(h.substr(0, pos), h.substr(pos + 2)); + } + } + + File out; + if (!output_file.empty()) { + if (!out.open(output_file, "ab")) { + printe("Failed to open file for writing\n"); + return 1; + } + if (out.lock()) { + printe("Failed to exclusively lock file\n"); + return 1; + } + } + + size_t resume_offset = 0; + if (!output_file.empty() && std::filesystem::exists(output_file)) { + resume_offset = std::filesystem::file_size(output_file); + if (resume_offset > 0) { + headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); + } + } + + progress_data data; + data.file_size = resume_offset; + + long long total_size = 0; + long long received_this_session = 0; + + auto response_handler = + [&](const httplib::Response & response) { + if (resume_offset > 0 && response.status != 206) { + printe("\nServer does not support resuming. Restarting download.\n"); + out.file = freopen(output_file.c_str(), "wb", out.file); + if (!out.file) { + return false; + } + data.file_size = 0; + } + if (progress) { + if (response.has_header("Content-Length")) { + total_size = std::stoll(response.get_header_value("Content-Length")); + } else if (response.has_header("Content-Range")) { + auto range = response.get_header_value("Content-Range"); + auto slash = range.find('/'); + if (slash != std::string::npos) { + total_size = std::stoll(range.substr(slash + 1)); + } + } + } + return true; + }; + + auto content_receiver = + [&](const char * chunk, size_t length) { + if (out.file && fwrite(chunk, 1, length, out.file) != length) { + return false; + } + if (response_str) { + response_str->append(chunk, length); + } + received_this_session += length; + + if (progress && total_size > 0) { + update_progress(&data, total_size, received_this_session, 0, 0); + } + return true; + }; + + auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); + + if (data.printed) { + printe("\n"); + } + + if (!res) { + auto err = res.error(); + printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); + return 1; + } + + if (res->status >= 400) { + printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); + return 1; + } + + } catch (const std::exception & e) { + printe("HTTP request failed: %s\n", e.what()); + return 1; + } + return 0; + } + +#endif // LLAMA_USE_CURL + static std::string human_readable_time(double seconds) { int hrs = static_cast(seconds) / 3600; int mins = (static_cast(seconds) % 3600) / 60; @@ -644,8 +759,8 @@ class HttpClient { str->append(static_cast(ptr), size * nmemb); return size * nmemb; } + }; -#endif class LlamaData { public: @@ -673,7 +788,6 @@ class LlamaData { } private: -#ifdef LLAMA_USE_CURL int download(const std::string & url, const std::string & output_file, const bool progress, const std::vector & headers = {}, std::string * response_str = nullptr) { HttpClient http; @@ -683,14 +797,6 @@ class LlamaData { return 0; } -#else - int download(const std::string &, const std::string &, const bool, const std::vector & = {}, - std::string * = nullptr) { - printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); - - return 1; - } -#endif // Helper function to handle model tag extraction and URL construction std::pair extract_model_and_tag(std::string & model, const std::string & base_url) {