@@ -745,6 +745,124 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 
 #endif // LLAMA_USE_CURL
 
+//
+// Docker registry functions
+//
+
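+// request an anonymous pull-scoped Bearer token from Docker Hub's token endpoint
+// (Docker Registry v2 token authentication)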
+static std::string common_docker_get_token(const std::string & repo) {
+    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
+
+    common_remote_params params;
+    auto res = common_remote_get_content(url, params);
+
+    if (res.first != 200) {
+        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
+    }
+
+    std::string response_str(res.second.begin(), res.second.end());
+    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
+
+    if (!response.contains("token")) {
+        throw std::runtime_error("Docker registry token response missing 'token' field");
+    }
+
+    return response["token"].get<std::string>();
+}
+
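+// resolve a Docker Hub model reference (e.g. ai/smollm2:135M-Q4_K_M) to a local
+// GGUF file: fetch the image manifest, locate the GGUF layer, download its blob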
+static std::string common_docker_resolve_model(const std::string & docker) {
+    // Parse ai/smollm2:135M-Q4_K_M
+    size_t colon_pos = docker.find(':');
+    std::string repo, tag;
+    if (colon_pos != std::string::npos) {
+        repo = docker.substr(0, colon_pos);
+        tag = docker.substr(colon_pos + 1);
+    } else {
+        repo = docker;
+        tag = "latest";
+    }
+
+    // ai/ is the default
+    size_t slash_pos = docker.find('/');
+    if (slash_pos == std::string::npos) {
+        repo.insert(0, "ai/");
+    }
+
+    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
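+    // the digest taken from the manifest is spliced into the blob URL below, so
+    // its format is validated before use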
+    try {
+        // --- helper: digest validation ---
+        auto validate_oci_digest = [](const std::string & digest) -> std::string {
+            // expected format: algo:hex; only sha256 (64 hex chars) is accepted for now,
+            // extend this check if other algorithms need to be supported in the future
+            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
+            std::smatch m;
+            if (!std::regex_match(digest, m, re)) {
+                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
+            }
+            // normalize hex to lowercase
+            std::string normalized = digest;
+            std::transform(normalized.begin() + 7, normalized.end(), normalized.begin() + 7, [](unsigned char c) {
+                return std::tolower(c);
+            });
+            return normalized;
+        };
+
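+        // Docker Hub requires a Bearer token even for public repositories;
+        // unauthenticated registry requests are rejected with HTTP 401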
+        std::string token = common_docker_get_token(repo); // Get authentication token
+
+        // Get manifest
+        const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
+        std::string manifest_url = url_prefix + "/manifests/" + tag;
+        common_remote_params manifest_params;
+        manifest_params.headers.push_back("Authorization: Bearer " + token);
+        manifest_params.headers.push_back(
+            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
+        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
+        if (manifest_res.first != 200) {
+            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
+        }
+
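+        // the GGUF weights are stored as an image layer; match the dedicated media
+        // type first, falling back to any media type that mentions "gguf"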
+        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
+        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
+        std::string gguf_digest; // Find the GGUF layer
+        if (manifest.contains("layers")) {
+            for (const auto & layer : manifest["layers"]) {
+                if (layer.contains("mediaType")) {
+                    std::string media_type = layer["mediaType"].get<std::string>();
+                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
+                        media_type.find("gguf") != std::string::npos) {
+                        gguf_digest = layer["digest"].get<std::string>();
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (gguf_digest.empty()) {
+            throw std::runtime_error("No GGUF layer found in Docker manifest");
+        }
+
+        // Validate & normalize digest
+        gguf_digest = validate_oci_digest(gguf_digest);
+        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
+
+        // Prepare local filename
+        std::string model_filename = repo;
+        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
+        model_filename += "_" + tag + ".gguf";
+        std::string local_path = fs_get_cache_file(model_filename);
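+        // e.g. ai/smollm2:135M-Q4_K_M is cached as ai_smollm2_135M-Q4_K_M.gguf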
+
+        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
+        if (!common_download_file_single(blob_url, local_path, token, false)) {
+            throw std::runtime_error("Failed to download Docker Model");
+        }
+
+        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
+        return local_path;
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
+        throw;
+    }
+}
+
 //
 // utils
 //
@@ -795,7 +913,9 @@ static handle_model_result common_params_handle_model(
     handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
-        if (!model.hf_repo.empty()) {
+        if (!model.docker_repo.empty()) { // handle Docker model references by resolving them to local paths
+            model.path = common_docker_resolve_model(model.docker_repo);
+        } else if (!model.hf_repo.empty()) {
             // short-hand to avoid specifying --hf-file -> default it to --model
             if (model.hf_file.empty()) {
                 if (model.path.empty()) {
@@ -2636,6 +2756,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.url = value;
         }
     ).set_env("LLAMA_ARG_MODEL_URL"));
+    add_opt(common_arg(
+        {"-dr", "--docker-repo"}, "[<repo>/]<model>[:quant]",
+        "Docker Hub model repository. repo is optional, defaults to ai/. quant is optional, defaults to :latest.\n"
+        "example: gemma3\n"
+        "(default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.model.docker_repo = value;
+        }
+    ).set_env("LLAMA_ARG_DOCKER_REPO"));
     add_opt(common_arg(
         {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
         "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"