Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
/common/build-info.* @ggerganov
/common/common.* @ggerganov
/common/console.* @ggerganov
/common/http.* @angt
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/sampling.* @ggerganov
Expand Down
1 change: 1 addition & 0 deletions common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
common.h
console.cpp
console.h
http.h
json-partial.cpp
json-partial.h
json-schema-to-grammar.cpp
Expand Down
81 changes: 4 additions & 77 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@
#include <thread>
#include <vector>

//#define LLAMA_USE_CURL

#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#else
#include <cpp-httplib/httplib.h>
#include "http.h"
#endif

#ifdef __linux__
Expand Down Expand Up @@ -596,77 +594,6 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &

#else

struct common_url {
std::string scheme;
std::string user;
std::string password;
std::string host;
std::string path;
};

static common_url parse_url(const std::string & url) {
common_url parts;
auto scheme_end = url.find("://");

if (scheme_end == std::string::npos) {
throw std::runtime_error("invalid URL: no scheme");
}
parts.scheme = url.substr(0, scheme_end);

if (parts.scheme != "http" && parts.scheme != "https") {
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
}

auto rest = url.substr(scheme_end + 3);
auto at_pos = rest.find('@');

if (at_pos != std::string::npos) {
auto auth = rest.substr(0, at_pos);
auto colon_pos = auth.find(':');
if (colon_pos != std::string::npos) {
parts.user = auth.substr(0, colon_pos);
parts.password = auth.substr(colon_pos + 1);
} else {
parts.user = auth;
}
rest = rest.substr(at_pos + 1);
}

auto slash_pos = rest.find('/');

if (slash_pos != std::string::npos) {
parts.host = rest.substr(0, slash_pos);
parts.path = rest.substr(slash_pos);
} else {
parts.host = rest;
parts.path = "/";
}
return parts;
}

static std::pair<httplib::Client, common_url> http_client(const std::string & url) {
common_url parts = parse_url(url);

if (parts.host.empty()) {
throw std::runtime_error("error: invalid URL format");
}

if (!parts.user.empty()) {
throw std::runtime_error("error: user:password@ not supported yet"); // TODO
}

httplib::Client cli(parts.scheme + "://" + parts.host);
cli.set_follow_location(true);

// TODO cert

return { std::move(cli), std::move(parts) };
}

static std::string show_masked_url(const common_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
}

static void print_progress(size_t current, size_t total) {
if (!is_output_a_tty()) {
return;
Expand Down Expand Up @@ -759,7 +686,7 @@ static bool common_download_file_single_online(const std::string & url,
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;

auto [cli, parts] = http_client(url);
auto [cli, parts] = common_http_client(url);

httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
if (!bearer_token.empty()) {
Expand Down Expand Up @@ -839,7 +766,7 @@ static bool common_download_file_single_online(const std::string & url,

// start the download
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
Expand Down Expand Up @@ -867,7 +794,7 @@ static bool common_download_file_single_online(const std::string & url,

std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
const common_remote_params & params) {
auto [cli, parts] = http_client(url);
auto [cli, parts] = common_http_client(url);

httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
for (const auto & header : params.headers) {
Expand Down
73 changes: 73 additions & 0 deletions common/http.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#pragma once

#include <cpp-httplib/httplib.h>

struct common_http_url {
std::string scheme;
std::string user;
std::string password;
std::string host;
std::string path;
};

static common_http_url common_http_parse_url(const std::string & url) {
common_http_url parts;
auto scheme_end = url.find("://");

if (scheme_end == std::string::npos) {
throw std::runtime_error("invalid URL: no scheme");
}
parts.scheme = url.substr(0, scheme_end);

if (parts.scheme != "http" && parts.scheme != "https") {
throw std::runtime_error("unsupported URL scheme: " + parts.scheme);
}

auto rest = url.substr(scheme_end + 3);
auto at_pos = rest.find('@');

if (at_pos != std::string::npos) {
auto auth = rest.substr(0, at_pos);
auto colon_pos = auth.find(':');
if (colon_pos != std::string::npos) {
parts.user = auth.substr(0, colon_pos);
parts.password = auth.substr(colon_pos + 1);
} else {
parts.user = auth;
}
rest = rest.substr(at_pos + 1);
}

auto slash_pos = rest.find('/');

if (slash_pos != std::string::npos) {
parts.host = rest.substr(0, slash_pos);
parts.path = rest.substr(slash_pos);
} else {
parts.host = rest;
parts.path = "/";
}
return parts;
}

static std::pair<httplib::Client, common_http_url> common_http_client(const std::string & url) {
common_http_url parts = common_http_parse_url(url);

if (parts.host.empty()) {
throw std::runtime_error("error: invalid URL format");
}

httplib::Client cli(parts.scheme + "://" + parts.host);

if (!parts.user.empty()) {
cli.set_basic_auth(parts.user, parts.password);
}

cli.set_follow_location(true);

return { std::move(cli), std::move(parts) };
}

static std::string common_http_show_masked_url(const common_http_url & parts) {
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
}
128 changes: 117 additions & 11 deletions tools/run/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <nlohmann/json.hpp>

#if defined(_WIN32)
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
Expand All @@ -22,6 +23,8 @@

#if defined(LLAMA_USE_CURL)
# include <curl/curl.h>
#else
# include "http.h"
#endif

#include <signal.h>
Expand Down Expand Up @@ -397,7 +400,6 @@ class File {
# endif
};

#ifdef LLAMA_USE_CURL
class HttpClient {
public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
Expand Down Expand Up @@ -428,6 +430,8 @@ class HttpClient {
return 0;
}

#ifdef LLAMA_USE_CURL

~HttpClient() {
if (chunk) {
curl_slist_free_all(chunk);
Expand Down Expand Up @@ -532,6 +536,117 @@ class HttpClient {
return curl_easy_perform(curl);
}

#else // LLAMA_USE_CURL is not defined

#define curl_off_t long long // temporary hack

private:
// this is a direct translation of the cURL download() above
int download(const std::string & url, const std::vector<std::string> & headers_vec, const std::string & output_file,
const bool progress, std::string * response_str = nullptr) {
try {
auto [cli, url_parts] = common_http_client(url);

httplib::Headers headers;
for (const auto & h : headers_vec) {
size_t pos = h.find(':');
if (pos != std::string::npos) {
headers.emplace(h.substr(0, pos), h.substr(pos + 2));
}
}

File out;
if (!output_file.empty()) {
if (!out.open(output_file, "ab")) {
printe("Failed to open file for writing\n");
return 1;
}
if (out.lock()) {
printe("Failed to exclusively lock file\n");
return 1;
}
}

size_t resume_offset = 0;
if (!output_file.empty() && std::filesystem::exists(output_file)) {
resume_offset = std::filesystem::file_size(output_file);
if (resume_offset > 0) {
headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-");
}
}

progress_data data;
data.file_size = resume_offset;

long long total_size = 0;
long long received_this_session = 0;

auto response_handler =
[&](const httplib::Response & response) {
if (resume_offset > 0 && response.status != 206) {
printe("\nServer does not support resuming. Restarting download.\n");
out.file = freopen(output_file.c_str(), "wb", out.file);
if (!out.file) {
return false;
}
data.file_size = 0;
}
if (progress) {
if (response.has_header("Content-Length")) {
total_size = std::stoll(response.get_header_value("Content-Length"));
} else if (response.has_header("Content-Range")) {
auto range = response.get_header_value("Content-Range");
auto slash = range.find('/');
if (slash != std::string::npos) {
total_size = std::stoll(range.substr(slash + 1));
}
}
}
return true;
};

auto content_receiver =
[&](const char * chunk, size_t length) {
if (out.file && fwrite(chunk, 1, length, out.file) != length) {
return false;
}
if (response_str) {
response_str->append(chunk, length);
}
received_this_session += length;

if (progress && total_size > 0) {
update_progress(&data, total_size, received_this_session, 0, 0);
}
return true;
};

auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver);

if (data.printed) {
printe("\n");
}

if (!res) {
auto err = res.error();
printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str());
return 1;
}

if (res->status >= 400) {
printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status);
return 1;
}

} catch (const std::exception & e) {
printe("HTTP request failed: %s\n", e.what());
return 1;
}
return 0;
}

#endif // LLAMA_USE_CURL

static std::string human_readable_time(double seconds) {
int hrs = static_cast<int>(seconds) / 3600;
int mins = (static_cast<int>(seconds) % 3600) / 60;
Expand Down Expand Up @@ -644,8 +759,8 @@ class HttpClient {
str->append(static_cast<char *>(ptr), size * nmemb);
return size * nmemb;
}

};
#endif

class LlamaData {
public:
Expand Down Expand Up @@ -673,7 +788,6 @@ class LlamaData {
}

private:
#ifdef LLAMA_USE_CURL
int download(const std::string & url, const std::string & output_file, const bool progress,
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
HttpClient http;
Expand All @@ -683,14 +797,6 @@ class LlamaData {

return 0;
}
#else
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
std::string * = nullptr) {
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);

return 1;
}
#endif

// Helper function to handle model tag extraction and URL construction
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
Expand Down
Loading