Skip to content

Commit e5bd64a

Browse files
committed
Add Go OCI library integration with go-containerregistry
So we can pull from any OCI registry, add authentication, etc. Add docker-style progress bars and resumable downloads to OCI pulls Update documentation with progress bars and resumable downloads info Signed-off-by: Eric Curtin <[email protected]>
1 parent ee09828 commit e5bd64a

File tree

10 files changed

+830
-100
lines changed

10 files changed

+830
-100
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
*.swp
2222
*.tmp
2323

24+
# OCI Go generated files
25+
oci-go/liboci.h
26+
2427
# IDE / OS
2528

2629
.cache/

common/CMakeLists.txt

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ add_library(${TARGET} STATIC
6565
log.h
6666
ngram-cache.cpp
6767
ngram-cache.h
68+
oci.cpp
69+
oci.h
6870
regex-partial.cpp
6971
regex-partial.h
7072
sampling.cpp
@@ -77,7 +79,38 @@ if (BUILD_SHARED_LIBS)
7779
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
7880
endif()
7981

80-
set(LLAMA_COMMON_EXTRA_LIBS build_info)
82+
# Build OCI Go library
83+
find_program(GO_EXECUTABLE go)
84+
if (GO_EXECUTABLE)
85+
set(OCI_GO_DIR ${CMAKE_SOURCE_DIR}/oci-go)
86+
set(OCI_LIB ${OCI_GO_DIR}/liboci.a)
87+
set(OCI_HEADER ${OCI_GO_DIR}/liboci.h)
88+
89+
add_custom_command(
90+
OUTPUT ${OCI_LIB} ${OCI_HEADER}
91+
COMMAND ${GO_EXECUTABLE} build -buildmode=c-archive -o ${OCI_LIB} ${OCI_GO_DIR}/oci.go
92+
WORKING_DIRECTORY ${OCI_GO_DIR}
93+
DEPENDS ${OCI_GO_DIR}/oci.go ${OCI_GO_DIR}/go.mod
94+
COMMENT "Building OCI Go library"
95+
)
96+
97+
add_custom_target(oci_go_lib DEPENDS ${OCI_LIB} ${OCI_HEADER})
98+
add_dependencies(${TARGET} oci_go_lib)
99+
100+
target_include_directories(${TARGET} PRIVATE ${OCI_GO_DIR})
101+
set(LLAMA_COMMON_EXTRA_LIBS build_info ${OCI_LIB})
102+
103+
# On macOS, the Go runtime requires CoreFoundation and Security frameworks
104+
if (APPLE)
105+
find_library(OCI_CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
106+
find_library(OCI_SECURITY_FRAMEWORK Security REQUIRED)
107+
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${OCI_CORE_FOUNDATION_FRAMEWORK} ${OCI_SECURITY_FRAMEWORK})
108+
endif()
109+
else()
110+
message(WARNING "Go compiler not found. OCI functionality will not be available.")
111+
set(LLAMA_COMMON_EXTRA_LIBS build_info)
112+
endif()
113+
81114

82115
# Use curl to download model url
83116
if (LLAMA_CURL)

common/arg.cpp

Lines changed: 21 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "gguf.h" // for reading GGUF splits
66
#include "json-schema-to-grammar.h"
77
#include "log.h"
8+
#include "oci.h"
89
#include "sampling.h"
910

1011
// fix problem with std::min and std::max
@@ -1043,116 +1044,37 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
10431044
// Docker registry functions
10441045
//
10451046

1046-
static std::string common_docker_get_token(const std::string & repo) {
1047-
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
1048-
1049-
common_remote_params params;
1050-
auto res = common_remote_get_content(url, params);
1051-
1052-
if (res.first != 200) {
1053-
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
1054-
}
1055-
1056-
std::string response_str(res.second.begin(), res.second.end());
1057-
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
1058-
1059-
if (!response.contains("token")) {
1060-
throw std::runtime_error("Docker registry token response missing 'token' field");
1061-
}
1062-
1063-
return response["token"].get<std::string>();
1064-
}
1065-
10661047
static std::string common_docker_resolve_model(const std::string & docker) {
1067-
// Parse ai/smollm2:135M-Q4_0
1068-
size_t colon_pos = docker.find(':');
1069-
std::string repo, tag;
1070-
if (colon_pos != std::string::npos) {
1071-
repo = docker.substr(0, colon_pos);
1072-
tag = docker.substr(colon_pos + 1);
1073-
} else {
1074-
repo = docker;
1075-
tag = "latest";
1076-
}
1048+
// Parse image reference (e.g., ai/smollm2:135M-Q4_0)
1049+
std::string image_ref = docker;
10771050

1078-
// ai/ is the default
1079-
size_t slash_pos = docker.find('/');
1051+
// ai/ is the default namespace for Docker Hub
1052+
size_t slash_pos = docker.find('/');
10801053
if (slash_pos == std::string::npos) {
1081-
repo.insert(0, "ai/");
1054+
image_ref = "ai/" + docker;
10821055
}
10831056

1084-
LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
1085-
try {
1086-
// --- helper: digest validation ---
1087-
auto validate_oci_digest = [](const std::string & digest) -> std::string {
1088-
// Expected: algo:hex ; start with sha256 (64 hex chars)
1089-
// You can extend this map if supporting other algorithms in future.
1090-
static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
1091-
std::smatch m;
1092-
if (!std::regex_match(digest, m, re)) {
1093-
throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
1094-
}
1095-
// normalize hex to lowercase
1096-
std::string normalized = digest;
1097-
std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
1098-
return std::tolower(c);
1099-
});
1100-
return normalized;
1101-
};
1102-
1103-
std::string token = common_docker_get_token(repo); // Get authentication token
1104-
1105-
// Get manifest
1106-
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
1107-
std::string manifest_url = url_prefix + "/manifests/" + tag;
1108-
common_remote_params manifest_params;
1109-
manifest_params.headers.push_back("Authorization: Bearer " + token);
1110-
manifest_params.headers.push_back(
1111-
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
1112-
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
1113-
if (manifest_res.first != 200) {
1114-
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
1115-
}
1116-
1117-
std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
1118-
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
1119-
std::string gguf_digest; // Find the GGUF layer
1120-
if (manifest.contains("layers")) {
1121-
for (const auto & layer : manifest["layers"]) {
1122-
if (layer.contains("mediaType")) {
1123-
std::string media_type = layer["mediaType"].get<std::string>();
1124-
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
1125-
media_type.find("gguf") != std::string::npos) {
1126-
gguf_digest = layer["digest"].get<std::string>();
1127-
break;
1128-
}
1129-
}
1130-
}
1131-
}
1132-
1133-
if (gguf_digest.empty()) {
1134-
throw std::runtime_error("No GGUF layer found in Docker manifest");
1135-
}
1057+
// Add registry prefix if not present
1058+
if (image_ref.find("registry-1.docker.io/") != 0 && image_ref.find("docker.io/") != 0 &&
1059+
image_ref.find("index.docker.io/") != 0) {
1060+
// For Docker Hub images without explicit registry
1061+
image_ref = "index.docker.io/" + image_ref;
1062+
}
11361063

1137-
// Validate & normalize digest
1138-
gguf_digest = validate_oci_digest(gguf_digest);
1139-
LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
1064+
try {
1065+
// Get cache directory
1066+
std::string cache_dir = fs_get_cache_directory();
11401067

1141-
// Prepare local filename
1142-
std::string model_filename = repo;
1143-
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
1144-
model_filename += "_" + tag + ".gguf";
1145-
std::string local_path = fs_get_cache_file(model_filename);
1068+
// Call the Go OCI library
1069+
auto result = oci_pull_model(image_ref, cache_dir);
11461070

1147-
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
1148-
if (!common_download_file_single(blob_url, local_path, token, false)) {
1149-
throw std::runtime_error("Failed to download Docker Model");
1071+
if (!result.success()) {
1072+
throw std::runtime_error("OCI pull failed: " + result.error_message);
11501073
}
11511074

1152-
LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
1153-
return local_path;
1075+
return result.local_path;
11541076
} catch (const std::exception & e) {
1155-
LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
1077+
LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what());
11561078
throw;
11571079
}
11581080
}

common/oci.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "oci.h"
2+
3+
#include "log.h"
4+
5+
#include <nlohmann/json.hpp>
6+
7+
// Include the Go-generated header
8+
extern "C" {
9+
#include "../oci-go/liboci.h"
10+
}
11+
12+
using json = nlohmann::ordered_json;
13+
14+
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir) {
15+
oci_pull_result result;
16+
result.error_code = 0;
17+
18+
// Call the Go function
19+
char * json_result = PullOCIModel(const_cast<char *>(imageRef.c_str()), const_cast<char *>(cacheDir.c_str()));
20+
21+
if (json_result == nullptr) {
22+
result.error_code = 1;
23+
result.error_message = "Failed to call OCI pull function";
24+
return result;
25+
}
26+
27+
try {
28+
// Parse the JSON result
29+
std::string json_str(json_result);
30+
auto j = json::parse(json_str);
31+
32+
if (j.contains("LocalPath")) {
33+
result.local_path = j["LocalPath"].get<std::string>();
34+
}
35+
if (j.contains("Digest")) {
36+
result.digest = j["Digest"].get<std::string>();
37+
}
38+
if (j.contains("Error") && !j["Error"].is_null()) {
39+
auto err = j["Error"];
40+
if (err.contains("Code")) {
41+
result.error_code = err["Code"].get<int>();
42+
}
43+
if (err.contains("Message")) {
44+
result.error_message = err["Message"].get<std::string>();
45+
}
46+
}
47+
} catch (const std::exception & e) {
48+
result.error_code = 1;
49+
result.error_message = std::string("Failed to parse result: ") + e.what();
50+
}
51+
52+
// Free the Go-allocated string
53+
FreeString(json_result);
54+
55+
return result;
56+
}

common/oci.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
// Structure to hold OCI pull results
6+
struct oci_pull_result {
7+
std::string local_path;
8+
std::string digest;
9+
int error_code;
10+
std::string error_message;
11+
12+
bool success() const { return error_code == 0; }
13+
};
14+
15+
// Pull a model from an OCI registry
16+
// imageRef: full image reference (e.g., "ai/smollm2:135M-Q4_0", "registry.io/user/model:tag")
17+
// cacheDir: directory to cache downloaded models
18+
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir);

0 commit comments

Comments
 (0)