5 | 5 | #include "gguf.h" // for reading GGUF splits |
6 | 6 | #include "json-schema-to-grammar.h" |
7 | 7 | #include "log.h" |
| 8 | +#ifdef LLAMA_USE_OCI |
| 9 | +#include "oci.h" |
| 10 | +#endif |
8 | 11 | #include "sampling.h" |
9 | 12 |
10 | 13 | // fix problem with std::min and std::max |
@@ -1043,119 +1046,42 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ |
1043 | 1046 | // Docker registry functions |
1044 | 1047 | // |
1045 | 1048 |
1046 | | -static std::string common_docker_get_token(const std::string & repo) { |
1047 | | - std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull"; |
1048 | | - |
1049 | | - common_remote_params params; |
1050 | | - auto res = common_remote_get_content(url, params); |
1051 | | - |
1052 | | - if (res.first != 200) { |
1053 | | - throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first)); |
1054 | | - } |
1055 | | - |
1056 | | - std::string response_str(res.second.begin(), res.second.end()); |
1057 | | - nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str); |
1058 | | - |
1059 | | - if (!response.contains("token")) { |
1060 | | - throw std::runtime_error("Docker registry token response missing 'token' field"); |
1061 | | - } |
1062 | | - |
1063 | | - return response["token"].get<std::string>(); |
1064 | | -} |
1065 | | - |
| 1049 | +#ifdef LLAMA_USE_OCI |
1066 | 1050 | static std::string common_docker_resolve_model(const std::string & docker) { |
1067 | | - // Parse ai/smollm2:135M-Q4_0 |
1068 | | - size_t colon_pos = docker.find(':'); |
1069 | | - std::string repo, tag; |
1070 | | - if (colon_pos != std::string::npos) { |
1071 | | - repo = docker.substr(0, colon_pos); |
1072 | | - tag = docker.substr(colon_pos + 1); |
1073 | | - } else { |
1074 | | - repo = docker; |
1075 | | - tag = "latest"; |
1076 | | - } |
| 1051 | + // Parse image reference (e.g., ai/smollm2:135M-Q4_0) |
| 1052 | + std::string image_ref = docker; |
1077 | 1053 |
1078 | | - // ai/ is the default |
1079 | | - size_t slash_pos = docker.find('/'); |
| 1054 | + // ai/ is the default namespace for Docker Hub |
| 1055 | + size_t slash_pos = docker.find('/'); |
1080 | 1056 | if (slash_pos == std::string::npos) { |
1081 | | - repo.insert(0, "ai/"); |
| 1057 | + image_ref = "ai/" + docker; |
1082 | 1058 | } |
1083 | 1059 |
1084 | | - LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str()); |
1085 | | - try { |
1086 | | - // --- helper: digest validation --- |
1087 | | - auto validate_oci_digest = [](const std::string & digest) -> std::string { |
1088 | | - // Expected: algo:hex ; start with sha256 (64 hex chars) |
1089 | | - // You can extend this map if supporting other algorithms in future. |
1090 | | - static const std::regex re("^sha256:([a-fA-F0-9]{64})$"); |
1091 | | - std::smatch m; |
1092 | | - if (!std::regex_match(digest, m, re)) { |
1093 | | - throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest); |
1094 | | - } |
1095 | | - // normalize hex to lowercase |
1096 | | - std::string normalized = digest; |
1097 | | - std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){ |
1098 | | - return std::tolower(c); |
1099 | | - }); |
1100 | | - return normalized; |
1101 | | - }; |
1102 | | - |
1103 | | - std::string token = common_docker_get_token(repo); // Get authentication token |
1104 | | - |
1105 | | - // Get manifest |
1106 | | - const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; |
1107 | | - std::string manifest_url = url_prefix + "/manifests/" + tag; |
1108 | | - common_remote_params manifest_params; |
1109 | | - manifest_params.headers.push_back("Authorization: Bearer " + token); |
1110 | | - manifest_params.headers.push_back( |
1111 | | - "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); |
1112 | | - auto manifest_res = common_remote_get_content(manifest_url, manifest_params); |
1113 | | - if (manifest_res.first != 200) { |
1114 | | - throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); |
1115 | | - } |
1116 | | - |
1117 | | - std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end()); |
1118 | | - nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str); |
1119 | | - std::string gguf_digest; // Find the GGUF layer |
1120 | | - if (manifest.contains("layers")) { |
1121 | | - for (const auto & layer : manifest["layers"]) { |
1122 | | - if (layer.contains("mediaType")) { |
1123 | | - std::string media_type = layer["mediaType"].get<std::string>(); |
1124 | | - if (media_type == "application/vnd.docker.ai.gguf.v3" || |
1125 | | - media_type.find("gguf") != std::string::npos) { |
1126 | | - gguf_digest = layer["digest"].get<std::string>(); |
1127 | | - break; |
1128 | | - } |
1129 | | - } |
1130 | | - } |
1131 | | - } |
1132 | | - |
1133 | | - if (gguf_digest.empty()) { |
1134 | | - throw std::runtime_error("No GGUF layer found in Docker manifest"); |
1135 | | - } |
| 1060 | + // Add registry prefix if not present |
| 1061 | + if (image_ref.find("registry-1.docker.io/") != 0 && image_ref.find("docker.io/") != 0 && |
| 1062 | + image_ref.find("index.docker.io/") != 0) { |
| 1063 | + // For Docker Hub images without explicit registry |
| 1064 | + image_ref = "index.docker.io/" + image_ref; |
| 1065 | + } |
1136 | 1066 |
1137 | | - // Validate & normalize digest |
1138 | | - gguf_digest = validate_oci_digest(gguf_digest); |
1139 | | - LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str()); |
| 1067 | + try { |
| 1068 | + // Get cache directory |
| 1069 | + std::string cache_dir = fs_get_cache_directory(); |
1140 | 1070 |
1141 | | - // Prepare local filename |
1142 | | - std::string model_filename = repo; |
1143 | | - std::replace(model_filename.begin(), model_filename.end(), '/', '_'); |
1144 | | - model_filename += "_" + tag + ".gguf"; |
1145 | | - std::string local_path = fs_get_cache_file(model_filename); |
| 1071 | + // Call the Go OCI library |
| 1072 | + auto result = oci_pull_model(image_ref, cache_dir); |
1146 | 1073 |
1147 | | - const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; |
1148 | | - if (!common_download_file_single(blob_url, local_path, token, false)) { |
1149 | | - throw std::runtime_error("Failed to download Docker Model"); |
| 1074 | + if (!result.success()) { |
| 1075 | + throw std::runtime_error("OCI pull failed: " + result.error_message); |
1150 | 1076 | } |
1151 | 1077 |
1152 | | - LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str()); |
1153 | | - return local_path; |
| 1078 | + return result.local_path; |
1154 | 1079 | } catch (const std::exception & e) { |
1155 | | - LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what()); |
| 1080 | + LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what()); |
1156 | 1081 | throw; |
1157 | 1082 | } |
1158 | 1083 | } |
| 1084 | +#endif // LLAMA_USE_OCI |
1159 | 1085 |
1160 | 1086 | // |
1161 | 1087 | // utils |
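Note: the new code path above relies on `oci.h` / `oci_pull_model()`, which is not shown in this diff. Below is a minimal sketch of an interface that would be consistent with the calls visible here; only `oci_pull_model()`, `success()`, `error_message`, and `local_path` appear in the diff, so the remaining names and fields are assumptions, not the actual header added by this PR.

```cpp
// Hypothetical sketch of the oci.h interface implied by the calls in
// common_docker_resolve_model() above; everything beyond the names used
// in the diff is an assumption.
#pragma once

#include <string>

struct oci_pull_result {
    std::string local_path;     // path of the downloaded GGUF file in the cache
    std::string error_message;  // populated when the pull fails
    bool        ok = false;     // assumed backing field for success()

    bool success() const { return ok; }
};

// Pulls the GGUF layer of an OCI image (e.g. "index.docker.io/ai/smollm2:135M-Q4_0")
// into cache_dir and returns its local path; implemented by the Go OCI library
// that is enabled with LLAMA_USE_OCI.
oci_pull_result oci_pull_model(const std::string & image_ref, const std::string & cache_dir);
```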
@@ -1208,7 +1134,11 @@ static handle_model_result common_params_handle_model( |
1208 | 1134 | // handle pre-fill default model path and url based on hf_repo and hf_file |
1209 | 1135 | { |
1210 | 1136 | if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths |
| 1137 | +#ifdef LLAMA_USE_OCI |
1211 | 1138 | model.path = common_docker_resolve_model(model.docker_repo); |
| 1139 | +#else |
| 1140 | + LOG_ERR("Docker model download requires building with LLAMA_USE_OCI (Go compiler needed)\n"); |
| 1141 | +#endif |
1212 | 1142 | } else if (!model.hf_repo.empty()) { |
1213 | 1143 | // short-hand to avoid specifying --hf-file -> default it to --model |
1214 | 1144 | if (model.hf_file.empty()) { |
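For clarity, the reference normalization performed by the new `common_docker_resolve_model()` can be exercised on its own. The sketch below mirrors the logic added above (bare names get the `ai/` namespace, Docker Hub images without an explicit registry get the `index.docker.io/` prefix); it is illustrative only, not shared code from the PR.

```cpp
#include <iostream>
#include <string>

// Mirrors the reference normalization in common_docker_resolve_model():
// bare names get the "ai/" namespace, and images without an explicit
// registry get the "index.docker.io/" prefix.
static std::string normalize_image_ref(const std::string & docker) {
    std::string image_ref = docker;
    if (docker.find('/') == std::string::npos) {
        image_ref = "ai/" + docker;
    }
    if (image_ref.rfind("registry-1.docker.io/", 0) != 0 &&
        image_ref.rfind("docker.io/", 0) != 0 &&
        image_ref.rfind("index.docker.io/", 0) != 0) {
        image_ref = "index.docker.io/" + image_ref;
    }
    return image_ref;
}

int main() {
    // "smollm2:135M-Q4_0"    -> "index.docker.io/ai/smollm2:135M-Q4_0"
    // "ai/smollm2:135M-Q4_0" -> "index.docker.io/ai/smollm2:135M-Q4_0"
    std::cout << normalize_image_ref("smollm2:135M-Q4_0") << "\n";
    std::cout << normalize_image_ref("ai/smollm2:135M-Q4_0") << "\n";
}
```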