Commit 4bf5549

Add docker protocol support for llama-server model loading (ggml-org#15790)

To pull and run models via: llama-server -dr gemma3

Adds validators and sanitizers for Docker Model URLs and metadata.

Signed-off-by: Eric Curtin <[email protected]>
1 parent f4e664f commit 4bf5549
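
The flag accepts an optional repository prefix and tag: for example, llama-server -dr gemma3 resolves to ai/gemma3:latest on Docker Hub, while llama-server -dr ai/smollm2:135M-Q4_K_M pins an explicit repository and quantization tag (both behaviors follow from the parsing in common_docker_resolve_model below).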

File tree

2 files changed: +135 -5 lines changed

common/arg.cpp

Lines changed: 130 additions & 1 deletion
@@ -745,6 +745,124 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 
 #endif // LLAMA_USE_CURL
 
+//
+// Docker registry functions
+//
+
+static std::string common_docker_get_token(const std::string & repo) {
+    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
+
+    common_remote_params params;
+    auto res = common_remote_get_content(url, params);
+
+    if (res.first != 200) {
+        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
+    }
+
+    std::string response_str(res.second.begin(), res.second.end());
+    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
+
+    if (!response.contains("token")) {
+        throw std::runtime_error("Docker registry token response missing 'token' field");
+    }
+
+    return response["token"].get<std::string>();
+}
+
+static std::string common_docker_resolve_model(const std::string & docker) {
+    // Parse e.g. ai/smollm2:135M-Q4_K_M into repo and tag
+    size_t colon_pos = docker.find(':');
+    std::string repo, tag;
+    if (colon_pos != std::string::npos) {
+        repo = docker.substr(0, colon_pos);
+        tag  = docker.substr(colon_pos + 1);
+    } else {
+        repo = docker;
+        tag  = "latest";
+    }
+
+    // ai/ is the default namespace
+    size_t slash_pos = docker.find('/');
+    if (slash_pos == std::string::npos) {
+        repo.insert(0, "ai/");
+    }
+
+    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
+    try {
+        // --- helper: digest validation ---
+        auto validate_oci_digest = [](const std::string & digest) -> std::string {
+            // Expected: algo:hex ; start with sha256 (64 hex chars)
+            // Extend this check if other algorithms need to be supported in the future.
+            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
+            std::smatch m;
+            if (!std::regex_match(digest, m, re)) {
+                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
+            }
+            // normalize hex to lowercase
+            std::string normalized = digest;
+            std::transform(normalized.begin() + 7, normalized.end(), normalized.begin() + 7,
+                           [](unsigned char c) { return std::tolower(c); });
+            return normalized;
+        };
+
+        std::string token = common_docker_get_token(repo); // Get authentication token
+
+        // Get manifest
+        const std::string url_prefix   = "https://registry-1.docker.io/v2/" + repo;
+        std::string       manifest_url = url_prefix + "/manifests/" + tag;
+        common_remote_params manifest_params;
+        manifest_params.headers.push_back("Authorization: Bearer " + token);
+        manifest_params.headers.push_back(
+            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
+        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
+        if (manifest_res.first != 200) {
+            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
+        }
+
+        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
+        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
+        std::string gguf_digest; // Find the GGUF layer
+        if (manifest.contains("layers")) {
+            for (const auto & layer : manifest["layers"]) {
+                if (layer.contains("mediaType")) {
+                    std::string media_type = layer["mediaType"].get<std::string>();
+                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
+                        media_type.find("gguf") != std::string::npos) {
+                        gguf_digest = layer["digest"].get<std::string>();
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (gguf_digest.empty()) {
+            throw std::runtime_error("No GGUF layer found in Docker manifest");
+        }
+
+        // Validate & normalize digest
+        gguf_digest = validate_oci_digest(gguf_digest);
+        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
+
+        // Prepare local filename
+        std::string model_filename = repo;
+        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
+        model_filename += "_" + tag + ".gguf";
+        std::string local_path = fs_get_cache_file(model_filename);
+
+        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
+        if (!common_download_file_single(blob_url, local_path, token, false)) {
+            throw std::runtime_error("Failed to download Docker Model");
+        }
+
+        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
+        return local_path;
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
+        throw;
+    }
+}
+
 //
 // utils
 //
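
For reference, a minimal sketch of the manifest shape that the layer scan above expects. The fields the code actually reads are layers[].mediaType and layers[].digest; the digest value below is a made-up placeholder, and k_example_manifest is a hypothetical name used only for illustration:

    // Hypothetical OCI image manifest payload; the loop above picks the first
    // layer whose media type equals the Docker AI GGUF type or mentions "gguf".
    static const char * k_example_manifest = R"json({
      "schemaVersion": 2,
      "mediaType": "application/vnd.oci.image.manifest.v1+json",
      "layers": [
        {
          "mediaType": "application/vnd.docker.ai.gguf.v3",
          "digest": "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
          "size": 123456789
        }
      ]
    })json";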
@@ -795,7 +913,9 @@ static handle_model_result common_params_handle_model(
     handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
-        if (!model.hf_repo.empty()) {
+        if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
+            model.path = common_docker_resolve_model(model.docker_repo);
+        } else if (!model.hf_repo.empty()) {
             // short-hand to avoid specifying --hf-file -> default it to --model
             if (model.hf_file.empty()) {
                 if (model.path.empty()) {
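
Note the ordering above: a non-empty docker_repo takes precedence over hf_repo. A minimal sketch of the resulting behavior, with hypothetical values for illustration:

    common_params params;
    params.model.docker_repo = "gemma3";     // resolved via common_docker_resolve_model()
    params.model.hf_repo     = "user/model"; // ignored while docker_repo is non-empty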
@@ -2636,6 +2756,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.url = value;
         }
     ).set_env("LLAMA_ARG_MODEL_URL"));
+    add_opt(common_arg(
+        { "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
+        "Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
+        "example: gemma3\n"
+        "(default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.model.docker_repo = value;
+        }
+    ).set_env("LLAMA_ARG_DOCKER_REPO"));
     add_opt(common_arg(
         {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
         "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"

common/common.h

Lines changed: 5 additions & 4 deletions
@@ -193,10 +193,11 @@ struct common_params_sampling {
 };
 
 struct common_params_model {
-    std::string path    = ""; // model local path // NOLINT
-    std::string url     = ""; // model url to download // NOLINT
-    std::string hf_repo = ""; // HF repo // NOLINT
-    std::string hf_file = ""; // HF file // NOLINT
+    std::string path        = ""; // model local path // NOLINT
+    std::string url         = ""; // model url to download // NOLINT
+    std::string hf_repo     = ""; // HF repo // NOLINT
+    std::string hf_file     = ""; // HF file // NOLINT
+    std::string docker_repo = ""; // Docker repo // NOLINT
 };
 
 struct common_params_speculative {
