Commit 3e66a68

Add Go OCI library integration with go-containerregistry
So we can pull from any OCI registry, add authentication, etc.

Add docker-style progress bars and resumable downloads to OCI pulls
Update documentation with progress bars and resumable downloads info
Make OCI Go build optional and skip editorconfig for oci-go

Signed-off-by: Eric Curtin <[email protected]>
1 parent ee09828 commit 3e66a68

11 files changed: +851 additions, −103 deletions


.ecrc

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 {
-    "Exclude": ["^\\.gitmodules$", "stb_image\\.h"],
+    "Exclude": ["^\\.gitmodules$", "stb_image\\.h", "oci-go/"],
     "Disable": {
         "IndentSize": true
     }

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -21,6 +21,9 @@
 *.swp
 *.tmp

+# OCI Go generated files
+oci-go/liboci.h
+
 # IDE / OS

 .cache/

common/CMakeLists.txt

Lines changed: 38 additions & 3 deletions

@@ -44,7 +44,7 @@ endif()

 set(TARGET common)

-add_library(${TARGET} STATIC
+set(COMMON_SOURCES
     arg.cpp
     arg.h
     base64.hpp
@@ -73,11 +73,46 @@ add_library(${TARGET} STATIC
     speculative.h
 )

+add_library(${TARGET} STATIC ${COMMON_SOURCES})
+
 if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()

-set(LLAMA_COMMON_EXTRA_LIBS build_info)
+# Build OCI Go library
+find_program(GO_EXECUTABLE go)
+if (GO_EXECUTABLE)
+    set(OCI_GO_DIR ${CMAKE_SOURCE_DIR}/oci-go)
+    set(OCI_LIB ${OCI_GO_DIR}/liboci.a)
+    set(OCI_HEADER ${OCI_GO_DIR}/liboci.h)
+
+    add_custom_command(
+        OUTPUT ${OCI_LIB} ${OCI_HEADER}
+        COMMAND ${GO_EXECUTABLE} build -buildmode=c-archive -o ${OCI_LIB} ${OCI_GO_DIR}/oci.go
+        WORKING_DIRECTORY ${OCI_GO_DIR}
+        DEPENDS ${OCI_GO_DIR}/oci.go ${OCI_GO_DIR}/go.mod
+        COMMENT "Building OCI Go library"
+    )
+
+    add_custom_target(oci_go_lib DEPENDS ${OCI_LIB} ${OCI_HEADER})
+    add_dependencies(${TARGET} oci_go_lib)
+
+    target_include_directories(${TARGET} PRIVATE ${OCI_GO_DIR})
+    target_sources(${TARGET} PRIVATE oci.cpp oci.h)
+    target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_OCI)
+    set(LLAMA_COMMON_EXTRA_LIBS build_info ${OCI_LIB})
+
+    # On macOS, the Go runtime requires CoreFoundation and Security frameworks
+    if (APPLE)
+        find_library(OCI_CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
+        find_library(OCI_SECURITY_FRAMEWORK Security REQUIRED)
+        set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${OCI_CORE_FOUNDATION_FRAMEWORK} ${OCI_SECURITY_FRAMEWORK})
+    endif()
+else()
+    message(WARNING "Go compiler not found. OCI functionality will not be available.")
+    set(LLAMA_COMMON_EXTRA_LIBS build_info)
+endif()
+

 # Use curl to download model url
 if (LLAMA_CURL)
@@ -172,7 +207,7 @@ endif ()

 target_include_directories(${TARGET} PUBLIC . ../vendor)
 target_compile_features (${TARGET} PUBLIC cxx_std_17)
-target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
+target_link_libraries (${TARGET} PUBLIC ${LLAMA_COMMON_EXTRA_LIBS} llama Threads::Threads)


 #
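
The `go build -buildmode=c-archive` step above emits both the static archive liboci.a and a generated C header liboci.h, which is why .gitignore now excludes the header. For that build mode to work, oci-go/oci.go has to be a `package main` with cgo exports and an unused `main`. The real oci.go is not shown in this excerpt, so the sketch below only illustrates the required shape, with placeholder bodies:

    package main

    /*
    #include <stdlib.h>
    */
    import "C"

    import "unsafe"

    //export PullOCIModel
    func PullOCIModel(imageRef, cacheDir *C.char) *C.char {
        // Placeholder body: the real implementation pulls the image into cacheDir
        // and returns a JSON result string allocated with C.CString.
        return C.CString("{}")
    }

    //export FreeString
    func FreeString(s *C.char) {
        C.free(unsafe.Pointer(s)) // release strings allocated with C.CString
    }

    // c-archive builds require package main, but main itself is never run when
    // liboci.a is linked into the common library.
    func main() {}

Each `//export` comment is what makes the function appear in the generated liboci.h, so the header is produced by the build rather than checked in.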

common/arg.cpp

Lines changed: 29 additions & 99 deletions

@@ -5,6 +5,9 @@
 #include "gguf.h" // for reading GGUF splits
 #include "json-schema-to-grammar.h"
 #include "log.h"
+#ifdef LLAMA_USE_OCI
+#include "oci.h"
+#endif
 #include "sampling.h"

 // fix problem with std::min and std::max
@@ -1043,119 +1046,42 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
 // Docker registry functions
 //

-static std::string common_docker_get_token(const std::string & repo) {
-    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
-
-    common_remote_params params;
-    auto res = common_remote_get_content(url, params);
-
-    if (res.first != 200) {
-        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
-    }
-
-    std::string response_str(res.second.begin(), res.second.end());
-    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
-
-    if (!response.contains("token")) {
-        throw std::runtime_error("Docker registry token response missing 'token' field");
-    }
-
-    return response["token"].get<std::string>();
-}
-
+#ifdef LLAMA_USE_OCI
 static std::string common_docker_resolve_model(const std::string & docker) {
-    // Parse ai/smollm2:135M-Q4_0
-    size_t colon_pos = docker.find(':');
-    std::string repo, tag;
-    if (colon_pos != std::string::npos) {
-        repo = docker.substr(0, colon_pos);
-        tag = docker.substr(colon_pos + 1);
-    } else {
-        repo = docker;
-        tag = "latest";
-    }
+    // Parse image reference (e.g., ai/smollm2:135M-Q4_0)
+    std::string image_ref = docker;

-    // ai/ is the default
-    size_t slash_pos = docker.find('/');
+    // ai/ is the default namespace for Docker Hub
+    size_t slash_pos = docker.find('/');
     if (slash_pos == std::string::npos) {
-        repo.insert(0, "ai/");
+        image_ref = "ai/" + docker;
     }

-    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
-    try {
-        // --- helper: digest validation ---
-        auto validate_oci_digest = [](const std::string & digest) -> std::string {
-            // Expected: algo:hex ; start with sha256 (64 hex chars)
-            // You can extend this map if supporting other algorithms in future.
-            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
-            std::smatch m;
-            if (!std::regex_match(digest, m, re)) {
-                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
-            }
-            // normalize hex to lowercase
-            std::string normalized = digest;
-            std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
-                return std::tolower(c);
-            });
-            return normalized;
-        };
-
-        std::string token = common_docker_get_token(repo); // Get authentication token
-
-        // Get manifest
-        const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
-        std::string manifest_url = url_prefix + "/manifests/" + tag;
-        common_remote_params manifest_params;
-        manifest_params.headers.push_back("Authorization: Bearer " + token);
-        manifest_params.headers.push_back(
-            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
-        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
-        if (manifest_res.first != 200) {
-            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
-        }
-
-        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
-        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
-        std::string gguf_digest; // Find the GGUF layer
-        if (manifest.contains("layers")) {
-            for (const auto & layer : manifest["layers"]) {
-                if (layer.contains("mediaType")) {
-                    std::string media_type = layer["mediaType"].get<std::string>();
-                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
-                        media_type.find("gguf") != std::string::npos) {
-                        gguf_digest = layer["digest"].get<std::string>();
-                        break;
-                    }
-                }
-            }
-        }
-
-        if (gguf_digest.empty()) {
-            throw std::runtime_error("No GGUF layer found in Docker manifest");
-        }
+    // Add registry prefix if not present
+    if (image_ref.find("registry-1.docker.io/") != 0 && image_ref.find("docker.io/") != 0 &&
+        image_ref.find("index.docker.io/") != 0) {
+        // For Docker Hub images without explicit registry
+        image_ref = "index.docker.io/" + image_ref;
+    }

-        // Validate & normalize digest
-        gguf_digest = validate_oci_digest(gguf_digest);
-        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
+    try {
+        // Get cache directory
+        std::string cache_dir = fs_get_cache_directory();

-        // Prepare local filename
-        std::string model_filename = repo;
-        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
-        model_filename += "_" + tag + ".gguf";
-        std::string local_path = fs_get_cache_file(model_filename);
+        // Call the Go OCI library
+        auto result = oci_pull_model(image_ref, cache_dir);

-        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
-        if (!common_download_file_single(blob_url, local_path, token, false)) {
-            throw std::runtime_error("Failed to download Docker Model");
+        if (!result.success()) {
+            throw std::runtime_error("OCI pull failed: " + result.error_message);
         }

-        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
-        return local_path;
+        return result.local_path;
     } catch (const std::exception & e) {
-        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
+        LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what());
         throw;
     }
 }
+#endif // LLAMA_USE_OCI

 //
 // utils
@@ -1208,7 +1134,11 @@ static handle_model_result common_params_handle_model(
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
         if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
+#ifdef LLAMA_USE_OCI
             model.path = common_docker_resolve_model(model.docker_repo);
+#else
+            LOG_ERR("Need to build with go compiler and LLAMA_USE_OCI\n");
+#endif
         } else if (!model.hf_repo.empty()) {
            // short-hand to avoid specifying --hf-file -> default it to --model
            if (model.hf_file.empty()) {
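
The removed code above did the registry work by hand: fetch an auth token, request the manifest, scan the layers for a GGUF media type, then download the matching blob. Those steps now live in the Go library. As a hedged illustration only (this is not the oci-go code from this commit, and it omits the progress bars and resumable downloads the commit adds), the same flow with go-containerregistry looks roughly like this:

    package main

    import (
        "fmt"
        "io"
        "os"
        "strings"

        "github.com/google/go-containerregistry/pkg/authn"
        "github.com/google/go-containerregistry/pkg/name"
        "github.com/google/go-containerregistry/pkg/v1/remote"
    )

    // pullGGUF fetches imageRef and writes its first layer whose media type
    // mentions "gguf" to destPath.
    func pullGGUF(imageRef, destPath string) error {
        ref, err := name.ParseReference(imageRef)
        if err != nil {
            return err
        }
        // DefaultKeychain reuses docker-style credentials, replacing the manual
        // token request against auth.docker.io.
        img, err := remote.Image(ref, remote.WithAuthFromKeychain(authn.DefaultKeychain))
        if err != nil {
            return err
        }
        layers, err := img.Layers()
        if err != nil {
            return err
        }
        for _, layer := range layers {
            mt, err := layer.MediaType()
            if err != nil {
                return err
            }
            if !strings.Contains(string(mt), "gguf") {
                continue
            }
            rc, err := layer.Compressed() // the blob exactly as stored in the registry
            if err != nil {
                return err
            }
            defer rc.Close()
            out, err := os.Create(destPath)
            if err != nil {
                return err
            }
            defer out.Close()
            _, err = io.Copy(out, rc)
            return err
        }
        return fmt.Errorf("no GGUF layer found in %s", imageRef)
    }

    func main() {
        if err := pullGGUF("index.docker.io/ai/smollm2:135M-Q4_0", "smollm2.gguf"); err != nil {
            fmt.Fprintln(os.Stderr, err)
            os.Exit(1)
        }
    }

Progress reporting and resumable transfers would wrap the final io.Copy; they are left out of this sketch.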

common/oci.cpp

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+#ifdef LLAMA_USE_OCI
+
+#include "oci.h"
+
+#include "log.h"
+
+#include <nlohmann/json.hpp>
+
+// Include the Go-generated header
+extern "C" {
+#include "../oci-go/liboci.h"
+}
+
+using json = nlohmann::ordered_json;
+
+oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir) {
+    oci_pull_result result;
+    result.error_code = 0;
+
+    // Call the Go function
+    char * json_result = PullOCIModel(const_cast<char *>(imageRef.c_str()), const_cast<char *>(cacheDir.c_str()));
+
+    if (json_result == nullptr) {
+        result.error_code = 1;
+        result.error_message = "Failed to call OCI pull function";
+        return result;
+    }
+
+    try {
+        // Parse the JSON result
+        std::string json_str(json_result);
+        auto j = json::parse(json_str);
+
+        if (j.contains("LocalPath")) {
+            result.local_path = j["LocalPath"].get<std::string>();
+        }
+        if (j.contains("Digest")) {
+            result.digest = j["Digest"].get<std::string>();
+        }
+        if (j.contains("Error") && !j["Error"].is_null()) {
+            auto err = j["Error"];
+            if (err.contains("Code")) {
+                result.error_code = err["Code"].get<int>();
+            }
+            if (err.contains("Message")) {
+                result.error_message = err["Message"].get<std::string>();
+            }
+        }
+    } catch (const std::exception & e) {
+        result.error_code = 1;
+        result.error_message = std::string("Failed to parse result: ") + e.what();
+    }
+
+    // Free the Go-allocated string
+    FreeString(json_result);
+
+    return result;
+}
+
+#endif // LLAMA_USE_OCI
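
oci.cpp expects the Go side to hand back a JSON document with LocalPath, Digest, and an optional Error object carrying Code and Message. The concrete Go types are not part of this excerpt; a minimal sketch that produces JSON in that shape, relying on Go's default field-name encoding, would be:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // pullError mirrors the Error.Code / Error.Message fields read in oci.cpp.
    type pullError struct {
        Code    int
        Message string
    }

    // pullResult mirrors the top-level LocalPath / Digest / Error fields.
    type pullResult struct {
        LocalPath string
        Digest    string
        Error     *pullError `json:",omitempty"`
    }

    func main() {
        success := pullResult{LocalPath: "/path/to/cache/model.gguf", Digest: "sha256:..."}
        failure := pullResult{Error: &pullError{Code: 1, Message: "manifest not found"}}

        for _, r := range []pullResult{success, failure} {
            b, _ := json.Marshal(r) // encoding/json uses the exported field names as keys
            fmt.Println(string(b))
        }
    }

Because PullOCIModel returns this document as a C string allocated on the Go side, oci.cpp releases it with FreeString once parsing is done, as shown above.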

common/oci.h

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+#pragma once
+
+#ifdef LLAMA_USE_OCI
+
+#include <string>
+
+// Structure to hold OCI pull results
+struct oci_pull_result {
+    std::string local_path;
+    std::string digest;
+    int error_code;
+    std::string error_message;
+
+    bool success() const { return error_code == 0; }
+};
+
+// Pull a model from an OCI registry
+// imageRef: full image reference (e.g., "ai/smollm2:135M-Q4_0", "registry.io/user/model:tag")
+// cacheDir: directory to cache downloaded models
+oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir);
+
+#endif // LLAMA_USE_OCI
