Skip to content

Commit d166c42

Browse files
committed
Add docker protocol support for llama-server model loading
This allows pulling and running models via: llama-server -d ai/smollm2:135M-Q4_K_M Signed-off-by: Eric Curtin <[email protected]>
1 parent 4f63cd7 commit d166c42

File tree

2 files changed

+118
-14
lines changed

2 files changed

+118
-14
lines changed

common/arg.cpp

Lines changed: 113 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,16 @@ static bool common_download_file_single(const std::string & url, const std::stri
431431
// start the download
432432
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
433433
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
434+
435+
// Write the updated JSON metadata file.
436+
metadata.update({
437+
{"url", url},
438+
{"etag", headers.etag},
439+
{"lastModified", headers.last_modified}
440+
});
441+
write_file(metadata_path, metadata.dump(4));
442+
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
443+
434444
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
435445
if (!was_perform_successful) {
436446
return false;
@@ -446,15 +456,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
446456
// Causes file to be closed explicitly here before we rename it.
447457
outfile.reset();
448458

449-
// Write the updated JSON metadata file.
450-
metadata.update({
451-
{"url", url},
452-
{"etag", headers.etag},
453-
{"lastModified", headers.last_modified}
454-
});
455-
write_file(metadata_path, metadata.dump(4));
456-
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
457-
458459
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
459460
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
460461
return false;
@@ -745,6 +746,97 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
745746

746747
#endif // LLAMA_USE_CURL
747748

749+
//
750+
// Docker registry functions
751+
//
752+
753+
static std::string common_docker_get_token(const std::string & repo) {
754+
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
755+
756+
common_remote_params params;
757+
auto res = common_remote_get_content(url, params);
758+
759+
if (res.first != 200) {
760+
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
761+
}
762+
763+
std::string response_str(res.second.begin(), res.second.end());
764+
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
765+
766+
if (!response.contains("token")) {
767+
throw std::runtime_error("Docker registry token response missing 'token' field");
768+
}
769+
770+
return response["token"].get<std::string>();
771+
}
772+
773+
static std::string common_docker_resolve_model(const std::string & docker) {
774+
// Parse ai/smollm2:135M-Q4_K_M
775+
size_t colon_pos = docker.find(':');
776+
std::string repo, tag;
777+
if (colon_pos != std::string::npos) {
778+
repo = docker.substr(0, colon_pos);
779+
tag = docker.substr(colon_pos + 1);
780+
} else {
781+
repo = docker;
782+
tag = "latest";
783+
}
784+
785+
LOG_INF("Downloading Docker Model: %s:%s\n", repo.c_str(), tag.c_str());
786+
try {
787+
std::string token = common_docker_get_token(repo); // Get authentication token
788+
789+
// Get manifest
790+
std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
791+
common_remote_params manifest_params;
792+
manifest_params.headers.push_back("Authorization: Bearer " + token);
793+
manifest_params.headers.push_back(
794+
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
795+
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
796+
if (manifest_res.first != 200) {
797+
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
798+
}
799+
800+
std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
801+
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
802+
std::string gguf_digest; // Find the GGUF layer
803+
if (manifest.contains("layers")) {
804+
for (const auto & layer : manifest["layers"]) {
805+
if (layer.contains("mediaType")) {
806+
std::string media_type = layer["mediaType"].get<std::string>();
807+
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
808+
media_type.find("gguf") != std::string::npos) {
809+
gguf_digest = layer["digest"].get<std::string>();
810+
break;
811+
}
812+
}
813+
}
814+
}
815+
816+
if (gguf_digest.empty()) {
817+
throw std::runtime_error("No GGUF layer found in Docker manifest");
818+
}
819+
820+
// Prepare local filename
821+
std::string model_filename = repo;
822+
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
823+
model_filename += "_" + tag + ".gguf";
824+
std::string local_path = fs_get_cache_file(model_filename);
825+
826+
// Download the blob using common_download_file_single with is_docker=true
827+
std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest;
828+
if (!common_download_file_single(blob_url, local_path, token, false)) {
829+
throw std::runtime_error("Failed to download Docker blob");
830+
}
831+
832+
LOG_INF("Downloaded Docker Model to: %s\n", local_path.c_str());
833+
return local_path;
834+
} catch (const std::exception & e) {
835+
LOG_ERR("Docker Model download failed: %s\n", e.what());
836+
throw;
837+
}
838+
}
839+
748840
//
749841
// utils
750842
//
@@ -795,7 +887,9 @@ static handle_model_result common_params_handle_model(
795887
handle_model_result result;
796888
// handle pre-fill default model path and url based on hf_repo and hf_file
797889
{
798-
if (!model.hf_repo.empty()) {
890+
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
891+
model.path = common_docker_resolve_model(model.docker_repo);
892+
} else if (!model.hf_repo.empty()) {
799893
// short-hand to avoid specifying --hf-file -> default it to --model
800894
if (model.hf_file.empty()) {
801895
if (model.path.empty()) {
@@ -2636,6 +2730,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
26362730
params.model.url = value;
26372731
}
26382732
).set_env("LLAMA_ARG_MODEL_URL"));
2733+
add_opt(common_arg(
2734+
{ "-d", "-dr", "--docker-repo" }, "<repo>/<model>[:quant]",
2735+
"Docker Hub model repository; quant is optional, default to latest.\n"
2736+
"example: ai/smollm2:135M-Q4_K_M\n"
2737+
"(default: unused)",
2738+
[](common_params & params, const std::string & value) {
2739+
params.model.docker_repo = value;
2740+
}
2741+
).set_env("LLAMA_ARG_DOCKER"));
26392742
add_opt(common_arg(
26402743
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
26412744
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"

common/common.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,11 @@ struct common_params_sampling {
193193
};
194194

195195
// Model source selection: exactly one of these is normally set by the CLI.
// All fields default to empty ("unset").
struct common_params_model {
    std::string path        = ""; // model local path                              // NOLINT
    std::string url         = ""; // model url to download                         // NOLINT
    std::string hf_repo     = ""; // HF repo                                       // NOLINT
    std::string hf_file     = ""; // HF file                                       // NOLINT
    std::string docker_repo = ""; // Docker repo reference <repo>/<model>[:tag],
                                  // not a URL (resolved to a local path later)    // NOLINT
};
201202

202203
struct common_params_speculative {

0 commit comments

Comments
 (0)