Skip to content

Commit 0bd8384

Browse files
committed
Add docker protocol support for llama-server model loading
- To pull and run models via: llama-server -d ai/smollm2:135M-Q4_K_M - Implement resumable downloads in common_download_file_single function - Add detection of partial download files (.downloadInProgress) - Check server support for HTTP Range requests via Accept-Ranges header - Implement HTTP Range request with "bytes=<start>-" header - Open files in append mode when resuming vs create mode for new downloads - Maintain backwards compatibility with existing functionality - Add some validators and sanitizers for Docker Model urls and metadata Signed-off-by: Eric Curtin <[email protected]>
1 parent 28b5f19 commit 0bd8384

File tree

2 files changed

+187
-24
lines changed

2 files changed

+187
-24
lines changed

common/arg.cpp

Lines changed: 182 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,24 @@ struct curl_slist_ptr {
220220
#define CURL_MAX_RETRY 3
221221
#define CURL_RETRY_DELAY_SECONDS 2
222222

223-
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) {
223+
static bool curl_perform_with_retry(const std::string & url,
224+
CURL * curl,
225+
int max_attempts,
226+
int retry_delay_seconds,
227+
const char * method_name,
228+
const std::string & path_temporary = "") {
224229
int remaining_attempts = max_attempts;
225230

226231
while (remaining_attempts > 0) {
227232
LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
228233

234+
if (std::filesystem::exists(path_temporary)) {
235+
const long partial_size = static_cast<long>(std::filesystem::file_size(path_temporary));
236+
LOG_INF("%s: server supports range requests, resuming download from byte %ld\n", __func__, partial_size);
237+
const std::string range_str = std::to_string(partial_size) + "-";
238+
curl_easy_setopt(curl, CURLOPT_RANGE, range_str.c_str());
239+
}
240+
229241
CURLcode res = curl_easy_perform(curl);
230242
if (res == CURLE_OK) {
231243
return true;
@@ -246,16 +258,16 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
246258

247259
// download one single file from remote URL to local path
248260
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
249-
// Check if the file already exists locally
250-
auto file_exists = std::filesystem::exists(path);
251-
252261
// If the file exists, check its JSON metadata companion file.
253262
std::string metadata_path = path + ".json";
254263
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
255264
std::string etag;
256265
std::string last_modified;
257266

258-
if (file_exists) {
267+
// Check if the file already exists locally
268+
auto file_exists = std::filesystem::exists(path);
269+
auto json_file_exists = std::filesystem::exists(metadata_path);
270+
if (json_file_exists) {
259271
if (offline) {
260272
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
261273
return true; // skip verification/downloading
@@ -289,6 +301,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
289301
struct common_load_model_from_url_headers {
290302
std::string etag;
291303
std::string last_modified;
304+
std::string accept_ranges;
292305
};
293306

294307
common_load_model_from_url_headers headers;
@@ -328,7 +341,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
328341
static std::regex header_regex("([^:]+): (.*)\r\n");
329342
static std::regex etag_regex("ETag", std::regex_constants::icase);
330343
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
331-
344+
static std::regex accept_ranges_regex("Accept-Ranges", std::regex_constants::icase);
332345
std::string header(buffer, n_items);
333346
std::smatch match;
334347
if (std::regex_match(header, match, header_regex)) {
@@ -338,6 +351,8 @@ static bool common_download_file_single(const std::string & url, const std::stri
338351
headers->etag = value;
339352
} else if (std::regex_match(key, match, last_modified_regex)) {
340353
headers->last_modified = value;
354+
} else if (std::regex_match(key, match, accept_ranges_regex)) {
355+
headers->accept_ranges = value;
341356
}
342357
}
343358
return n_items;
@@ -366,28 +381,48 @@ static bool common_download_file_single(const std::string & url, const std::stri
366381

367382
// if head_request_ok is false, we don't have the etag or last-modified headers
368383
// we leave should_download as-is, which is true if the file does not exist
384+
bool should_download_from_scratch = false;
369385
if (head_request_ok) {
370386
// check if ETag or Last-Modified headers are different
371387
// if it is, we need to download the file again
372388
if (!etag.empty() && etag != headers.etag) {
373389
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
374390
should_download = true;
391+
should_download_from_scratch = true;
375392
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
376393
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
377394
should_download = true;
395+
should_download_from_scratch = true;
378396
}
379397
}
380398

381399
if (should_download) {
382-
std::string path_temporary = path + ".downloadInProgress";
383-
if (file_exists) {
400+
if (file_exists &&
401+
headers.accept_ranges.empty()) { // Resumable downloads not supported, delete and start again.
384402
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
385403
if (remove(path.c_str()) != 0) {
386404
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
387405
return false;
388406
}
389407
}
390408

409+
std::string path_temporary = path + ".downloadInProgress";
410+
if (should_download_from_scratch) {
411+
if (std::filesystem::exists(path_temporary)) {
412+
if (remove(path_temporary.c_str()) != 0) {
413+
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
414+
return false;
415+
}
416+
}
417+
418+
if (std::filesystem::exists(path)) {
419+
if (remove(path.c_str()) != 0) {
420+
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
421+
return false;
422+
}
423+
}
424+
}
425+
391426
// Set the output file
392427

393428
struct FILE_deleter {
@@ -396,7 +431,8 @@ static bool common_download_file_single(const std::string & url, const std::stri
396431
}
397432
};
398433

399-
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
434+
// Always open file in append mode could be resuming
435+
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "ab"));
400436
if (!outfile) {
401437
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
402438
return false;
@@ -431,7 +467,19 @@ static bool common_download_file_single(const std::string & url, const std::stri
431467
// start the download
432468
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
433469
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
434-
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET");
470+
471+
// Write the updated JSON metadata file.
472+
metadata.update({
473+
{"url", url},
474+
{"etag", headers.etag},
475+
{"lastModified", headers.last_modified}
476+
});
477+
write_file(metadata_path, metadata.dump(4));
478+
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
479+
480+
const bool was_perform_successful =
481+
curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET",
482+
headers.accept_ranges.empty() ? "" : path_temporary);
435483
if (!was_perform_successful) {
436484
return false;
437485
}
@@ -446,15 +494,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
446494
// Causes file to be closed explicitly here before we rename it.
447495
outfile.reset();
448496

449-
// Write the updated JSON metadata file.
450-
metadata.update({
451-
{"url", url},
452-
{"etag", headers.etag},
453-
{"lastModified", headers.last_modified}
454-
});
455-
write_file(metadata_path, metadata.dump(4));
456-
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
457-
458497
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
459498
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
460499
return false;
@@ -745,6 +784,118 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
745784

746785
#endif // LLAMA_USE_CURL
747786

787+
//
788+
// Docker registry functions
789+
//
790+
791+
static std::string common_docker_get_token(const std::string & repo) {
792+
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
793+
794+
common_remote_params params;
795+
auto res = common_remote_get_content(url, params);
796+
797+
if (res.first != 200) {
798+
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
799+
}
800+
801+
std::string response_str(res.second.begin(), res.second.end());
802+
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
803+
804+
if (!response.contains("token")) {
805+
throw std::runtime_error("Docker registry token response missing 'token' field");
806+
}
807+
808+
return response["token"].get<std::string>();
809+
}
810+
811+
static std::string common_docker_resolve_model(const std::string & docker) {
812+
// Parse ai/smollm2:135M-Q4_K_M
813+
size_t colon_pos = docker.find(':');
814+
std::string repo, tag;
815+
if (colon_pos != std::string::npos) {
816+
repo = docker.substr(0, colon_pos);
817+
tag = docker.substr(colon_pos + 1);
818+
} else {
819+
repo = docker;
820+
tag = "latest";
821+
}
822+
823+
LOG_INF("Downloading Docker Model: %s:%s\n", repo.c_str(), tag.c_str());
824+
try {
825+
// --- helper: digest validation ---
826+
auto validate_oci_digest = [](const std::string & digest) -> std::string {
827+
// Expected: algo:hex ; start with sha256 (64 hex chars)
828+
// You can extend this map if supporting other algorithms in future.
829+
static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
830+
std::smatch m;
831+
if (!std::regex_match(digest, m, re)) {
832+
throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
833+
}
834+
// normalize hex to lowercase
835+
std::string normalized = digest;
836+
std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
837+
return std::tolower(c);
838+
});
839+
return normalized;
840+
};
841+
842+
std::string token = common_docker_get_token(repo); // Get authentication token
843+
844+
// Get manifest
845+
std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
846+
common_remote_params manifest_params;
847+
manifest_params.headers.push_back("Authorization: Bearer " + token);
848+
manifest_params.headers.push_back(
849+
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
850+
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
851+
if (manifest_res.first != 200) {
852+
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
853+
}
854+
855+
std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
856+
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
857+
std::string gguf_digest; // Find the GGUF layer
858+
if (manifest.contains("layers")) {
859+
for (const auto & layer : manifest["layers"]) {
860+
if (layer.contains("mediaType")) {
861+
std::string media_type = layer["mediaType"].get<std::string>();
862+
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
863+
media_type.find("gguf") != std::string::npos) {
864+
gguf_digest = layer["digest"].get<std::string>();
865+
break;
866+
}
867+
}
868+
}
869+
}
870+
871+
if (gguf_digest.empty()) {
872+
throw std::runtime_error("No GGUF layer found in Docker manifest");
873+
}
874+
875+
// Validate & normalize digest
876+
gguf_digest = validate_oci_digest(gguf_digest);
877+
LOG_DBG("Using validated digest: %s\n", gguf_digest.c_str());
878+
879+
// Prepare local filename
880+
std::string model_filename = repo;
881+
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
882+
model_filename += "_" + tag + ".gguf";
883+
std::string local_path = fs_get_cache_file(model_filename);
884+
885+
// Download the blob using common_download_file_single with is_docker=true
886+
std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest;
887+
if (!common_download_file_single(blob_url, local_path, token, false)) {
888+
throw std::runtime_error("Failed to download Docker Model");
889+
}
890+
891+
LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
892+
return local_path;
893+
} catch (const std::exception & e) {
894+
LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
895+
throw;
896+
}
897+
}
898+
748899
//
749900
// utils
750901
//
@@ -795,7 +946,9 @@ static handle_model_result common_params_handle_model(
795946
handle_model_result result;
796947
// handle pre-fill default model path and url based on hf_repo and hf_file
797948
{
798-
if (!model.hf_repo.empty()) {
949+
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
950+
model.path = common_docker_resolve_model(model.docker_repo);
951+
} else if (!model.hf_repo.empty()) {
799952
// short-hand to avoid specifying --hf-file -> default it to --model
800953
if (model.hf_file.empty()) {
801954
if (model.path.empty()) {
@@ -2636,6 +2789,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
26362789
params.model.url = value;
26372790
}
26382791
).set_env("LLAMA_ARG_MODEL_URL"));
2792+
add_opt(common_arg(
2793+
{ "-d", "-dr", "--docker-repo" }, "<repo>/<model>[:quant]",
2794+
"Docker Hub model repository; quant is optional, default to latest.\n"
2795+
"example: ai/smollm2:135M-Q4_K_M\n"
2796+
"(default: unused)",
2797+
[](common_params & params, const std::string & value) {
2798+
params.model.docker_repo = value;
2799+
}
2800+
).set_env("LLAMA_ARG_DOCKER"));
26392801
add_opt(common_arg(
26402802
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
26412803
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"

common/common.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,11 @@ struct common_params_sampling {
193193
};
194194

195195
struct common_params_model {
196-
std::string path = ""; // model local path // NOLINT
197-
std::string url = ""; // model url to download // NOLINT
198-
std::string hf_repo = ""; // HF repo // NOLINT
199-
std::string hf_file = ""; // HF file // NOLINT
196+
std::string path = ""; // model local path // NOLINT
197+
std::string url = ""; // model url to download // NOLINT
198+
std::string hf_repo = ""; // HF repo // NOLINT
199+
std::string hf_file = ""; // HF file // NOLINT
200+
std::string docker_repo = ""; // Docker Model url to download // NOLINT
200201
};
201202

202203
struct common_params_speculative {

0 commit comments

Comments
 (0)