Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 213 additions & 98 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,12 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
}

// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
bool is_docker = false) {
// Standard download logic for non-docker files
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);

Expand Down Expand Up @@ -306,6 +311,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress

http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
// Check if hf-token or bearer-token was specified
Expand All @@ -321,7 +327,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif

typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
typedef size_t (*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;

Expand All @@ -332,7 +338,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
Expand All @@ -343,8 +349,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
return n_items;
};

curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);

Expand All @@ -353,115 +358,117 @@ static bool common_download_file_single(const std::string & url, const std::stri
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
if (!was_perform_successful) {
head_request_ok = false;
}
}

long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code == 200) {
head_request_ok = true;
} else {
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code == 200) {
head_request_ok = true;
} else {
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
head_request_ok = false;
}

// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
if (head_request_ok) {
// check if ETag or Last-Modified headers are different
// if it is, we need to download the file again
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}

if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}

// if head_request_ok is false, we don't have the etag or last-modified headers
// we leave should_download as-is, which is true if the file does not exist
if (head_request_ok) {
// check if ETag or Last-Modified headers are different
// if it is, we need to download the file again
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
// Set the output file

struct FILE_deleter {
void operator()(FILE * f) const { fclose(f); }
};

if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
}

// Set the output file
typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *) fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());

struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);

std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}

typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}

// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};

// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
__func__, llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(),
headers.last_modified.c_str());
bool was_perform_successful =
curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
if (!was_perform_successful) {
return false;
}

std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}

return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// Causes file to be closed explicitly here before we rename it.
outfile.reset();

// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
if (!was_perform_successful) {
return false;
}
// Write the updated JSON metadata file.
metadata.update({
{ "url", url },
{ "etag", headers.etag },
{ "lastModified", headers.last_modified }
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());

long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}

// Causes file to be closed explicitly here before we rename it.
outfile.reset();

// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
write_file(metadata_path, metadata.dump(4));
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());

if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
}
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
}

return true;
}
Expand Down Expand Up @@ -712,7 +719,11 @@ bool common_has_curl() {
return false;
}

static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) {
static bool common_download_file_single(const std::string &,
const std::string &,
const std::string &,
bool,
bool = false) {
LOG_ERR("error: built without CURL, cannot download model from internet\n");
return false;
}
Expand Down Expand Up @@ -745,6 +756,101 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &

#endif // LLAMA_USE_CURL

//
// Docker registry functions
//

static std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);

if (res.first != 200) {
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
}

std::string response_str(res.second.begin(), res.second.end());
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);

if (!response.contains("token")) {
throw std::runtime_error("Docker registry token response missing 'token' field");
}

return response["token"].get<std::string>();
}

static std::string common_docker_resolve_model(const std::string & docker) {
// Parse ai/smollm2:135M-Q4_K_M
size_t colon_pos = docker.find(':');
std::string repo, tag;
if (colon_pos != std::string::npos) {
repo = docker.substr(0, colon_pos);
tag = docker.substr(colon_pos + 1);
} else {
repo = docker;
tag = "latest";
}

LOG_INF("Downloading Docker AI model: %s:%s\n", repo.c_str(), tag.c_str());
try {
std::string token = common_docker_get_token(repo); // Get authentication token

// Get manifest
std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
}

std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
std::string gguf_digest; // Find the GGUF layer
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}

// Prepare local filename
std::string model_filename = repo;
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
model_filename += "_" + tag + ".gguf";
std::string local_path = fs_get_cache_file(model_filename);
if (std::filesystem::exists(local_path)) { // Check if already downloaded
LOG_INF("Docker model already cached: %s\n", local_path.c_str());
return local_path;
}

// Download the blob using common_download_file_single with is_docker=true
std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false, true)) {
throw std::runtime_error("Failed to download Docker blob");
}

LOG_INF("Downloaded Docker model to: %s\n", local_path.c_str());
return local_path;
} catch (const std::exception & e) {
LOG_ERR("Docker model download failed: %s\n", e.what());
throw;
}
}

//
// utils
//
Expand Down Expand Up @@ -795,7 +901,9 @@ static handle_model_result common_params_handle_model(
handle_model_result result;
// handle pre-fill default model path and url based on hf_repo and hf_file
{
if (!model.hf_repo.empty()) {
if (!model.docker.empty()) { // Handle Docker URLs by resolving them to local paths
model.path = common_docker_resolve_model(model.docker);
} else if (!model.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (model.hf_file.empty()) {
if (model.path.empty()) {
Expand Down Expand Up @@ -2636,6 +2744,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.url = value;
}
).set_env("LLAMA_ARG_MODEL_URL"));
add_opt(common_arg(
{ "-d", "-dr", "--docker", "--docker-repo" }, "<repo>/<model>[:quant]",
"Docker Hub model repository; quant is optional, default to latest.\n"
"example: ai/smollm2:135M-Q4_K_M\n"
"(default: unused)",
[](common_params & params, const std::string & value) { params.model.docker = value; })
.set_env("LLAMA_ARG_DOCKER"));
add_opt(common_arg(
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
Expand Down
Loading