Commit f16b5b6

Add docker:// protocol support for llama-server model pulling
So we can pull and run models from Docker Hub via:

    llama-server -m docker://ai/smollm2:135M-Q4_K_M

Signed-off-by: Eric Curtin <[email protected]>
1 parent 0fce7a1 commit f16b5b6
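For orientation, below is a minimal standalone sketch (not part of the commit) of the resolution steps the new code performs for the example reference above: split the reference into repository and tag, then derive the Docker Hub token, manifest, and blob endpoints used in common/arg.cpp further down. The `main` function and print statements are illustrative only.

```cpp
// Hypothetical illustration of the docker:// resolution flow added in this commit.
// Endpoints and the cache-file naming scheme are taken from the diff below.
#include <algorithm>
#include <cstdio>
#include <string>

int main() {
    std::string ref  = "docker://ai/smollm2:135M-Q4_K_M";
    std::string spec = ref.substr(9);                      // strip "docker://"
    size_t      c    = spec.find(':');
    std::string repo = c == std::string::npos ? spec : spec.substr(0, c);
    std::string tag  = c == std::string::npos ? "latest" : spec.substr(c + 1);

    // 1. anonymous pull token for the repository
    std::string token_url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
    // 2. image manifest (lists layers; the GGUF layer is picked by mediaType)
    std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
    // 3. blob download for the GGUF layer digest, cached under a flattened filename
    std::string filename = repo;
    std::replace(filename.begin(), filename.end(), '/', '_');
    filename += "_" + tag + ".gguf";

    std::printf("repo=%s tag=%s\n%s\n%s\ncache file=%s\n",
                repo.c_str(), tag.c_str(), token_url.c_str(), manifest_url.c_str(), filename.c_str());
    return 0;
}
```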

3 files changed: +200 -10 lines changed


common/arg.cpp

Lines changed: 190 additions & 7 deletions
@@ -244,6 +244,10 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
     return false;
 }

+struct FILE_deleter {
+    void operator()(FILE * f) const { fclose(f); }
+};
+
 // download one single file from remote URL to local path
 static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
     // Check if the file already exists locally
@@ -389,13 +393,6 @@ static bool common_download_file_single(const std::string & url, const std::stri
     }

     // Set the output file
-
-    struct FILE_deleter {
-        void operator()(FILE * f) const {
-            fclose(f);
-        }
-    };
-
     std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
     if (!outfile) {
         LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
@@ -745,6 +742,192 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &

 #endif // LLAMA_USE_CURL

+//
+// Docker registry functions
+//
+
+static std::string common_docker_get_token(const std::string & repo) {
+    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
+
+    common_remote_params params;
+    auto res = common_remote_get_content(url, params);
+
+    if (res.first != 200) {
+        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
+    }
+
+    std::string response_str(res.second.begin(), res.second.end());
+    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
+
+    if (!response.contains("token")) {
+        throw std::runtime_error("Docker registry token response missing 'token' field");
+    }
+
+    return response["token"].get<std::string>();
+}
+
+#ifdef LLAMA_USE_CURL
+
+// Helper function to download Docker blob directly to file
+static bool common_docker_download_blob(const std::string & blob_url,
+                                        const std::string & token,
+                                        const std::string & local_path) {
+    curl_ptr       curl(curl_easy_init(), &curl_easy_cleanup);
+    curl_slist_ptr http_headers;
+    if (!curl.get()) {
+        LOG_ERR("%s: curl_easy_init() failed\n", __func__);
+        return false;
+    }
+
+    // Prepare temporary filename for safe downloading
+    std::string path_temporary = local_path + ".tmp";
+    std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
+    if (!outfile) {
+        LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_temporary.c_str());
+        return false;
+    }
+
+    // Set up CURL options
+    curl_easy_setopt(curl.get(), CURLOPT_URL, blob_url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
+
+    // Set up write callback to stream directly to file
+    typedef size_t (*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
+    auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
+        return fwrite(data, size, nmemb, (FILE *) fd);
+    };
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+    curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
+
+# if defined(_WIN32)
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+# endif
+
+    // Set headers
+    http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
+    std::string auth_header = "Authorization: Bearer " + token;
+    http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
+
+    // Perform the download
+    CURLcode res = curl_easy_perform(curl.get());
+    if (res != CURLE_OK) {
+        LOG_ERR("%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res));
+        outfile.reset();  // Close file before removing
+        std::filesystem::remove(path_temporary);
+        return false;
+    }
+
+    // Check HTTP response code
+    long response_code;
+    curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &response_code);
+    if (response_code != 200) {
+        LOG_ERR("%s: HTTP error %ld\n", __func__, response_code);
+        outfile.reset();  // Close file before removing
+        std::filesystem::remove(path_temporary);
+        return false;
+    }
+
+    // Close the file and move to final location
+    outfile.reset();
+
+    if (std::filesystem::exists(local_path)) {
+        std::filesystem::remove(local_path);
+    }
+
+    std::filesystem::rename(path_temporary, local_path);
+
+    return true;
+}
+
+#else
+
+static bool common_docker_download_blob(const std::string &, const std::string &, const std::string &) {
+    LOG_ERR("error: built without CURL, cannot download Docker blob\n");
+    return false;
+}
+
+#endif // LLAMA_USE_CURL
+
+std::string common_docker_resolve_model(const std::string & docker_url) {
+    // Parse docker://ai/smollm2:135M-Q4_K_M
+    if (docker_url.substr(0, 9) != "docker://") {
+        return docker_url;  // Not a docker URL, return as-is
+    }
+
+    std::string model_spec = docker_url.substr(9);  // Remove "docker://"
+
+    // Parse ai/smollm2:135M-Q4_K_M
+    size_t colon_pos = model_spec.find(':');
+    std::string repo, tag;
+    if (colon_pos != std::string::npos) {
+        repo = model_spec.substr(0, colon_pos);
+        tag  = model_spec.substr(colon_pos + 1);
+    } else {
+        repo = model_spec;
+        tag  = "latest";
+    }
+
+    LOG_INF("Downloading Docker AI model: %s:%s\n", repo.c_str(), tag.c_str());
+    try {
+        std::string token = common_docker_get_token(repo);  // Get authentication token
+
+        // Get manifest
+        std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
+        common_remote_params manifest_params;
+        manifest_params.headers.push_back("Authorization: Bearer " + token);
+        manifest_params.headers.push_back(
+            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
+        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
+        if (manifest_res.first != 200) {
+            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
+        }
+
+        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
+        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
+        std::string gguf_digest;  // Find the GGUF layer
+        if (manifest.contains("layers")) {
+            for (const auto & layer : manifest["layers"]) {
+                if (layer.contains("mediaType")) {
+                    std::string media_type = layer["mediaType"].get<std::string>();
+                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
+                        media_type.find("gguf") != std::string::npos) {
+                        gguf_digest = layer["digest"].get<std::string>();
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (gguf_digest.empty()) {
+            throw std::runtime_error("No GGUF layer found in Docker manifest");
+        }
+
+        // Prepare local filename
+        std::string model_filename = repo;
+        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
+        model_filename += "_" + tag + ".gguf";
+        std::string local_path = fs_get_cache_file(model_filename);
+        if (std::filesystem::exists(local_path)) {  // Check if already downloaded
+            LOG_INF("Docker model already cached: %s\n", local_path.c_str());
+            return local_path;
+        }
+
+        // Download the blob using streaming approach
+        std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest;
+        if (!common_docker_download_blob(blob_url, token, local_path)) {
+            throw std::runtime_error("Failed to download Docker blob");
+        }
+
+        LOG_INF("Downloaded Docker model to: %s\n", local_path.c_str());
+        return local_path;
+    } catch (const std::exception & e) {
+        LOG_ERR("Docker model download failed: %s\n", e.what());
+        throw;
+    }
+}
+
 //
 // utils
 //
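Taken together, the functions above give callers a single entry point: resolve the docker:// reference to a cached GGUF file, then load it as usual. A hedged usage sketch (the main function and error handling are illustrative; it assumes a CURL-enabled build linking the common library, and mirrors the common/common.cpp hunk further down):

```cpp
// Hypothetical caller: resolve a docker:// reference to a cached GGUF file, then load it.
#include <cstdio>
#include <stdexcept>
#include <string>

#include "arg.h"    // common_docker_resolve_model
#include "llama.h"

int main() {
    std::string path;
    try {
        // Downloads into the local cache on first use; later runs reuse the cached file.
        path = common_docker_resolve_model("docker://ai/smollm2:135M-Q4_K_M");
    } catch (const std::exception & e) {
        std::fprintf(stderr, "resolve failed: %s\n", e.what());
        return 1;
    }

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(path.c_str(), mparams);
    if (model == NULL) {
        return 1;
    }

    // ... use the model ...
    llama_model_free(model);
    return 0;
}
```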

common/arg.h

Lines changed: 3 additions & 0 deletions
@@ -87,3 +87,6 @@ struct common_remote_params {
 };
 // get remote file content, returns <http_code, raw_response_body>
 std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
+
+// Docker registry functions
+std::string common_docker_resolve_model(const std::string & docker_url);
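The contract of this declaration, as implemented in common/arg.cpp above, is that inputs without the docker:// scheme come back untouched, so call sites can apply it to any model path. A small hedged smoke test (the main function and sample strings are illustrative):

```cpp
// Hypothetical check of common_docker_resolve_model's pass-through behaviour.
#include <cassert>
#include <string>

#include "arg.h"

int main() {
    // Plain paths and non-docker URLs are returned as-is ...
    assert(common_docker_resolve_model("models/smollm2.gguf") == "models/smollm2.gguf");

    // ... while docker:// references resolve to a .gguf file in the local cache
    // (network download on first use; assumes a CURL-enabled build).
    std::string p = common_docker_resolve_model("docker://ai/smollm2:135M-Q4_K_M");
    assert(p.size() >= 5 && p.substr(p.size() - 5) == ".gguf");
    return 0;
}
```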

common/common.cpp

Lines changed: 7 additions & 3 deletions
@@ -8,6 +8,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "arg.h"

 #include <algorithm>
 #include <cinttypes>
@@ -899,10 +900,13 @@ struct common_init_result common_init_from_params(common_params & params) {
     common_init_result iparams;
     auto mparams = common_model_params_to_llama(params);

-    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+    // Resolve Docker URLs if needed
+    std::string resolved_model_path = common_docker_resolve_model(params.model.path);
+
+    llama_model * model = llama_model_load_from_file(resolved_model_path.c_str(), mparams);
     if (model == NULL) {
         LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
+                __func__, resolved_model_path.c_str());
         return iparams;
     }

@@ -913,7 +917,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     llama_context * lctx = llama_init_from_model(model, cparams);
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
-                __func__, params.model.path.c_str());
+                __func__, resolved_model_path.c_str());
         llama_model_free(model);
         return iparams;
     }
