diff --git a/common/arg.cpp b/common/arg.cpp index 406fbc2f06fe4..82510faff2ddb 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -245,7 +245,12 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma } // download one single file from remote URL to local path -static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) { +static bool common_download_file_single(const std::string & url, + const std::string & path, + const std::string & bearer_token, + bool offline, + bool is_docker = false) { + // Standard download logic for non-docker files // Check if the file already exists locally auto file_exists = std::filesystem::exists(path); @@ -306,6 +311,7 @@ static bool common_download_file_single(const std::string & url, const std::stri // Set the URL, allow to follow http redirection curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); // Check if hf-token or bearer-token was specified @@ -321,7 +327,7 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + typedef size_t (*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; @@ -332,7 +338,7 @@ static bool common_download_file_single(const std::string & url, const std::stri std::string header(buffer, n_items); std::smatch match; if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; + const std::string & key = match[1]; const std::string & value = match[2]; if (std::regex_match(key, match, etag_regex)) { headers->etag = value; @@ -343,8 +349,7 @@ static bool common_download_file_single(const std::string & url, const std::stri return n_items; }; - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); @@ -353,115 +358,117 @@ static bool common_download_file_single(const std::string & url, const std::stri bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); if (!was_perform_successful) { head_request_ok = false; - } + } - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code == 200) { - head_request_ok = true; - } else { - LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - head_request_ok = false; - } + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; + } + + // if head_request_ok is false, we don't have the etag or last-modified headers + // we leave should_download as-is, which is true if the file does not exist + if (head_request_ok) { + // check if ETag or Last-Modified headers are different + // if it is, we need to download the file again + if (!etag.empty() && etag != headers.etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), + headers.etag.c_str()); + should_download = true; + } else if (!last_modified.empty() && last_modified != headers.last_modified) { + LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, + last_modified.c_str(), headers.last_modified.c_str()); + should_download = true; + } + } + + if (should_download) { + std::string path_temporary = path + ".downloadInProgress"; + if (file_exists) { + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } - // if head_request_ok is false, we don't have the etag or last-modified headers - // we leave should_download as-is, which is true if the file does not exist - if (head_request_ok) { - // check if ETag or Last-Modified headers are different - // if it is, we need to download the file again - if (!etag.empty() && etag != headers.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); - should_download = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - } - } + // Set the output file + + struct FILE_deleter { + void operator()(FILE * f) const { fclose(f); } + }; - if (should_download) { - std::string path_temporary = path + ".downloadInProgress"; - if (file_exists) { - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); + if (!outfile) { + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; } - } - // Set the output file + typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); + auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { + return fwrite(data, size, nmemb, (FILE *) fd); + }; + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); - struct FILE_deleter { - void operator()(FILE * f) const { - fclose(f); - } - }; + // display download progress + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); - if (!outfile) { - LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); - return false; - } + // helper function to hide password in URL + auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { + std::size_t protocol_pos = url.find("://"); + if (protocol_pos == std::string::npos) { + return url; // Malformed URL + } - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); - auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { - return fwrite(data, size, nmemb, (FILE *)fd); - }; - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L); - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); + std::size_t at_pos = url.find('@', protocol_pos + 3); + if (at_pos == std::string::npos) { + return url; // No password in URL + } - // display download progress - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L); + return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); + }; - // helper function to hide password in URL - auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string { - std::size_t protocol_pos = url.find("://"); - if (protocol_pos == std::string::npos) { - return url; // Malformed URL + // start the download + LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", + __func__, llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), + headers.last_modified.c_str()); + bool was_perform_successful = + curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET"); + if (!was_perform_successful) { + return false; } - std::size_t at_pos = url.find('@', protocol_pos + 3); - if (at_pos == std::string::npos) { - return url; // No password in URL + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code < 200 || http_code >= 400) { + LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); + return false; } - return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos); - }; + // Causes file to be closed explicitly here before we rename it. + outfile.reset(); - // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET"); - if (!was_perform_successful) { - return false; - } + // Write the updated JSON metadata file. + metadata.update({ + { "url", url }, + { "etag", headers.etag }, + { "lastModified", headers.last_modified } + }); + write_file(metadata_path, metadata.dump(4)); + LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - long http_code = 0; - curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code < 200 || http_code >= 400) { - LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); - return false; - } - - // Causes file to be closed explicitly here before we rename it. - outfile.reset(); - - // Write the updated JSON metadata file. - metadata.update({ - {"url", url}, - {"etag", headers.etag}, - {"lastModified", headers.last_modified} - }); - write_file(metadata_path, metadata.dump(4)); - LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); - - if (rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return false; + if (rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return false; + } + } else { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); } - } else { - LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - } return true; } @@ -712,7 +719,11 @@ bool common_has_curl() { return false; } -static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) { +static bool common_download_file_single(const std::string &, + const std::string &, + const std::string &, + bool, + bool = false) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; } @@ -745,6 +756,101 @@ std::pair> common_remote_get_content(const std::string & #endif // LLAMA_USE_CURL +// +// Docker registry functions +// + +static std::string common_docker_get_token(const std::string & repo) { + std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull"; + + common_remote_params params; + auto res = common_remote_get_content(url, params); + + if (res.first != 200) { + throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first)); + } + + std::string response_str(res.second.begin(), res.second.end()); + nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str); + + if (!response.contains("token")) { + throw std::runtime_error("Docker registry token response missing 'token' field"); + } + + return response["token"].get(); +} + +static std::string common_docker_resolve_model(const std::string & docker) { + // Parse ai/smollm2:135M-Q4_K_M + size_t colon_pos = docker.find(':'); + std::string repo, tag; + if (colon_pos != std::string::npos) { + repo = docker.substr(0, colon_pos); + tag = docker.substr(colon_pos + 1); + } else { + repo = docker; + tag = "latest"; + } + + LOG_INF("Downloading Docker AI model: %s:%s\n", repo.c_str(), tag.c_str()); + try { + std::string token = common_docker_get_token(repo); // Get authentication token + + // Get manifest + std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag; + common_remote_params manifest_params; + manifest_params.headers.push_back("Authorization: Bearer " + token); + manifest_params.headers.push_back( + "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + auto manifest_res = common_remote_get_content(manifest_url, manifest_params); + if (manifest_res.first != 200) { + throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); + } + + std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end()); + nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str); + std::string gguf_digest; // Find the GGUF layer + if (manifest.contains("layers")) { + for (const auto & layer : manifest["layers"]) { + if (layer.contains("mediaType")) { + std::string media_type = layer["mediaType"].get(); + if (media_type == "application/vnd.docker.ai.gguf.v3" || + media_type.find("gguf") != std::string::npos) { + gguf_digest = layer["digest"].get(); + break; + } + } + } + } + + if (gguf_digest.empty()) { + throw std::runtime_error("No GGUF layer found in Docker manifest"); + } + + // Prepare local filename + std::string model_filename = repo; + std::replace(model_filename.begin(), model_filename.end(), '/', '_'); + model_filename += "_" + tag + ".gguf"; + std::string local_path = fs_get_cache_file(model_filename); + if (std::filesystem::exists(local_path)) { // Check if already downloaded + LOG_INF("Docker model already cached: %s\n", local_path.c_str()); + return local_path; + } + + // Download the blob using common_download_file_single with is_docker=true + std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest; + if (!common_download_file_single(blob_url, local_path, token, false, true)) { + throw std::runtime_error("Failed to download Docker blob"); + } + + LOG_INF("Downloaded Docker model to: %s\n", local_path.c_str()); + return local_path; + } catch (const std::exception & e) { + LOG_ERR("Docker model download failed: %s\n", e.what()); + throw; + } +} + // // utils // @@ -795,7 +901,9 @@ static handle_model_result common_params_handle_model( handle_model_result result; // handle pre-fill default model path and url based on hf_repo and hf_file { - if (!model.hf_repo.empty()) { + if (!model.docker.empty()) { // Handle Docker URLs by resolving them to local paths + model.path = common_docker_resolve_model(model.docker); + } else if (!model.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (model.hf_file.empty()) { if (model.path.empty()) { @@ -2636,6 +2744,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.model.url = value; } ).set_env("LLAMA_ARG_MODEL_URL")); + add_opt(common_arg( + { "-d", "-dr", "--docker", "--docker-repo" }, "/[:quant]", + "Docker Hub model repository; quant is optional, default to latest.\n" + "example: ai/smollm2:135M-Q4_K_M\n" + "(default: unused)", + [](common_params & params, const std::string & value) { params.model.docker = value; }) + .set_env("LLAMA_ARG_DOCKER")); add_opt(common_arg( {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" diff --git a/common/common.h b/common/common.h index 85b3b879d4536..764015d77f136 100644 --- a/common/common.h +++ b/common/common.h @@ -197,6 +197,7 @@ struct common_params_model { std::string url = ""; // model url to download // NOLINT std::string hf_repo = ""; // HF repo // NOLINT std::string hf_file = ""; // HF file // NOLINT + std::string docker = ""; // Docker AI model url to download // NOLINT }; struct common_params_speculative {