diff --git a/common/arg.cpp b/common/arg.cpp index fcee0c4470077..032a9e76e6f83 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -244,8 +244,101 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma return false; } +struct FILE_deleter { + void operator()(FILE * f) const { fclose(f); } +}; + // download one single file from remote URL to local path -static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) { +static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline, bool is_docker = false) { + // For docker downloads, use simplified logic without caching/metadata + if (is_docker) { + // Check if the file already exists locally (simple existence check for docker) + if (std::filesystem::exists(path)) { + LOG_INF("%s: docker file already cached: %s\n", __func__, path.c_str()); + return true; + } + + if (offline) { + LOG_ERR("%s: required docker file is not available in cache (offline mode): %s\n", __func__, path.c_str()); + return false; + } + + // For docker downloads, proceed directly to download without HEAD requests or metadata + std::string path_temporary = path + ".tmp"; + + // Remove any existing temporary file + if (std::filesystem::exists(path_temporary)) { + std::filesystem::remove(path_temporary); + } + + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); + if (!outfile) { + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str()); + return false; + } + + // Initialize libcurl + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + if (!curl) { + LOG_ERR("%s: error initializing libcurl\n", __func__); + return false; + } + + // Set the URL and options + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + if (!bearer_token.empty()) { + std::string auth_header = "Authorization: Bearer " + bearer_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd); + auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t { + return fwrite(data, size, nmemb, (FILE *)fd); + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get()); + + // Perform the download + CURLcode res = curl_easy_perform(curl.get()); + if (res != CURLE_OK) { + LOG_ERR("%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + outfile.reset(); + std::filesystem::remove(path_temporary); + return false; + } + + // Check HTTP response code + long response_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &response_code); + if (response_code != 200) { + LOG_ERR("%s: HTTP error %ld\n", __func__, response_code); + outfile.reset(); + std::filesystem::remove(path_temporary); + return false; + } + + // Close the file and move to final location + outfile.reset(); + + if (std::filesystem::exists(path)) { + std::filesystem::remove(path); + } + + std::filesystem::rename(path_temporary, path); + return true; + } + + // Standard download logic for non-docker files // Check if the file already exists locally auto file_exists = std::filesystem::exists(path); @@ -389,13 +482,6 @@ static bool common_download_file_single(const std::string & url, const std::stri } // Set the output file - - struct FILE_deleter { - void operator()(FILE * f) const { - fclose(f); - } - }; - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); if (!outfile) { LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); @@ -473,7 +559,7 @@ static bool common_download_file_multiple(const std::vector> futures_download; for (auto const & item : urls) { futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline); + return common_download_file_single(it.first, it.second, bearer_token, offline, false); }, item)); } @@ -497,7 +583,7 @@ static bool common_download_model( return false; } - if (!common_download_file_single(model.url, model.path, bearer_token, offline)) { + if (!common_download_file_single(model.url, model.path, bearer_token, offline, false)) { return false; } @@ -712,7 +798,7 @@ bool common_has_curl() { return false; } -static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) { +static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool, bool = false) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; } @@ -745,6 +831,108 @@ std::pair> common_remote_get_content(const std::string & #endif // LLAMA_USE_CURL +// +// Docker registry functions +// + +std::string common_docker_get_token(const std::string & repo) { + std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull"; + + common_remote_params params; + auto res = common_remote_get_content(url, params); + + if (res.first != 200) { + throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first)); + } + + std::string response_str(res.second.begin(), res.second.end()); + nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str); + + if (!response.contains("token")) { + throw std::runtime_error("Docker registry token response missing 'token' field"); + } + + return response["token"].get(); +} + +std::string common_docker_resolve_model(const std::string & docker_url) { + // Parse docker://ai/smollm2:135M-Q4_K_M + if (docker_url.substr(0, 9) != "docker://") { + return docker_url; // Not a docker URL, return as-is + } + + std::string model_spec = docker_url.substr(9); // Remove "docker://" + + // Parse ai/smollm2:135M-Q4_K_M + size_t colon_pos = model_spec.find(':'); + std::string repo, tag; + if (colon_pos != std::string::npos) { + repo = model_spec.substr(0, colon_pos); + tag = model_spec.substr(colon_pos + 1); + } else { + repo = model_spec; + tag = "latest"; + } + + LOG_INF("Downloading Docker AI model: %s:%s\n", repo.c_str(), tag.c_str()); + try { + std::string token = common_docker_get_token(repo); // Get authentication token + + // Get manifest + std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag; + common_remote_params manifest_params; + manifest_params.headers.push_back("Authorization: Bearer " + token); + manifest_params.headers.push_back( + "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + auto manifest_res = common_remote_get_content(manifest_url, manifest_params); + if (manifest_res.first != 200) { + throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); + } + + std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end()); + nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str); + std::string gguf_digest; // Find the GGUF layer + if (manifest.contains("layers")) { + for (const auto & layer : manifest["layers"]) { + if (layer.contains("mediaType")) { + std::string media_type = layer["mediaType"].get(); + if (media_type == "application/vnd.docker.ai.gguf.v3" || + media_type.find("gguf") != std::string::npos) { + gguf_digest = layer["digest"].get(); + break; + } + } + } + } + + if (gguf_digest.empty()) { + throw std::runtime_error("No GGUF layer found in Docker manifest"); + } + + // Prepare local filename + std::string model_filename = repo; + std::replace(model_filename.begin(), model_filename.end(), '/', '_'); + model_filename += "_" + tag + ".gguf"; + std::string local_path = fs_get_cache_file(model_filename); + if (std::filesystem::exists(local_path)) { // Check if already downloaded + LOG_INF("Docker model already cached: %s\n", local_path.c_str()); + return local_path; + } + + // Download the blob using common_download_file_single with is_docker=true + std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest; + if (!common_download_file_single(blob_url, local_path, token, false, true)) { + throw std::runtime_error("Failed to download Docker blob"); + } + + LOG_INF("Downloaded Docker model to: %s\n", local_path.c_str()); + return local_path; + } catch (const std::exception & e) { + LOG_ERR("Docker model download failed: %s\n", e.what()); + throw; + } +} + // // utils // @@ -793,6 +981,12 @@ static handle_model_result common_params_handle_model( const std::string & model_path_default, bool offline) { handle_model_result result; + + // Handle Docker URLs by resolving them to local paths + if (!model.path.empty()) { + model.path = common_docker_resolve_model(model.path); + } + // handle pre-fill default model path and url based on hf_repo and hf_file { if (!model.hf_repo.empty()) { diff --git a/common/arg.h b/common/arg.h index 70bea100fd4f2..ac2a3bc246993 100644 --- a/common/arg.h +++ b/common/arg.h @@ -87,3 +87,7 @@ struct common_remote_params { }; // get remote file content, returns std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); + +// Docker registry functions +std::string common_docker_get_token(const std::string & repo); +std::string common_docker_resolve_model(const std::string & docker_url); diff --git a/common/common.cpp b/common/common.cpp index 0c92d4d57ddbf..293eac370a4fa 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,9 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "arg.h" + +#include #include #include @@ -890,7 +893,6 @@ std::string fs_get_cache_file(const std::string & filename) { return cache_directory + filename; } - // // Model utils //