Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 205 additions & 11 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,101 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
return false;
}

struct FILE_deleter {
void operator()(FILE * f) const { fclose(f); }
};

// download one single file from remote URL to local path
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline, bool is_docker = false) {
// For docker downloads, use simplified logic without caching/metadata
if (is_docker) {
// Check if the file already exists locally (simple existence check for docker)
if (std::filesystem::exists(path)) {
LOG_INF("%s: docker file already cached: %s\n", __func__, path.c_str());
return true;
}

if (offline) {
LOG_ERR("%s: required docker file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
}

// For docker downloads, proceed directly to download without HEAD requests or metadata
std::string path_temporary = path + ".tmp";

// Remove any existing temporary file
if (std::filesystem::exists(path_temporary)) {
std::filesystem::remove(path_temporary);
}

std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
return false;
}

// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}

// Set the URL and options
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);

http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
if (!bearer_token.empty()) {
std::string auth_header = "Authorization: Bearer " + bearer_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);

#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif

typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());

// Perform the download
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
LOG_ERR("%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
outfile.reset();
std::filesystem::remove(path_temporary);
return false;
}

// Check HTTP response code
long response_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &response_code);
if (response_code != 200) {
LOG_ERR("%s: HTTP error %ld\n", __func__, response_code);
outfile.reset();
std::filesystem::remove(path_temporary);
return false;
}

// Close the file and move to final location
outfile.reset();

if (std::filesystem::exists(path)) {
std::filesystem::remove(path);
}

std::filesystem::rename(path_temporary, path);
return true;
}

// Standard download logic for non-docker files
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);

Expand Down Expand Up @@ -389,13 +482,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
}

// Set the output file

struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};

std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
Expand Down Expand Up @@ -473,7 +559,7 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
std::vector<std::future<bool>> futures_download;
for (auto const & item : urls) {
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline);
return common_download_file_single(it.first, it.second, bearer_token, offline, false);
}, item));
}

Expand All @@ -497,7 +583,7 @@ static bool common_download_model(
return false;
}

if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
if (!common_download_file_single(model.url, model.path, bearer_token, offline, false)) {
return false;
}

Expand Down Expand Up @@ -712,7 +798,7 @@ bool common_has_curl() {
return false;
}

static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) {
static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool, bool = false) {
LOG_ERR("error: built without CURL, cannot download model from internet\n");
return false;
}
Expand Down Expand Up @@ -745,6 +831,108 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &

#endif // LLAMA_USE_CURL

//
// Docker registry functions
//

std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);

if (res.first != 200) {
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
}

std::string response_str(res.second.begin(), res.second.end());
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);

if (!response.contains("token")) {
throw std::runtime_error("Docker registry token response missing 'token' field");
}

return response["token"].get<std::string>();
}

std::string common_docker_resolve_model(const std::string & docker_url) {
// Parse docker://ai/smollm2:135M-Q4_K_M
if (docker_url.substr(0, 9) != "docker://") {
return docker_url; // Not a docker URL, return as-is
}

std::string model_spec = docker_url.substr(9); // Remove "docker://"

// Parse ai/smollm2:135M-Q4_K_M
size_t colon_pos = model_spec.find(':');
std::string repo, tag;
if (colon_pos != std::string::npos) {
repo = model_spec.substr(0, colon_pos);
tag = model_spec.substr(colon_pos + 1);
} else {
repo = model_spec;
tag = "latest";
}

LOG_INF("Downloading Docker AI model: %s:%s\n", repo.c_str(), tag.c_str());
try {
std::string token = common_docker_get_token(repo); // Get authentication token

// Get manifest
std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
}

std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
std::string gguf_digest; // Find the GGUF layer
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}

// Prepare local filename
std::string model_filename = repo;
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
model_filename += "_" + tag + ".gguf";
std::string local_path = fs_get_cache_file(model_filename);
if (std::filesystem::exists(local_path)) { // Check if already downloaded
LOG_INF("Docker model already cached: %s\n", local_path.c_str());
return local_path;
}

// Download the blob using common_download_file_single with is_docker=true
std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false, true)) {
throw std::runtime_error("Failed to download Docker blob");
}

LOG_INF("Downloaded Docker model to: %s\n", local_path.c_str());
return local_path;
} catch (const std::exception & e) {
LOG_ERR("Docker model download failed: %s\n", e.what());
throw;
}
}

//
// utils
//
Expand Down Expand Up @@ -793,6 +981,12 @@ static handle_model_result common_params_handle_model(
const std::string & model_path_default,
bool offline) {
handle_model_result result;

// Handle Docker URLs by resolving them to local paths
if (!model.path.empty()) {
model.path = common_docker_resolve_model(model.path);
}

// handle pre-fill default model path and url based on hf_repo and hf_file
{
if (!model.hf_repo.empty()) {
Expand Down
4 changes: 4 additions & 0 deletions common/arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,7 @@ struct common_remote_params {
};
// get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);

// Docker registry functions
std::string common_docker_get_token(const std::string & repo);
std::string common_docker_resolve_model(const std::string & docker_url);
4 changes: 3 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "arg.h"

#include <nlohmann/json.hpp>

#include <algorithm>
#include <cinttypes>
Expand Down Expand Up @@ -890,7 +893,6 @@ std::string fs_get_cache_file(const std::string & filename) {
return cache_directory + filename;
}


//
// Model utils
//
Expand Down