2 changes: 1 addition & 1 deletion .ecrc
@@ -1,5 +1,5 @@
{
"Exclude": ["^\\.gitmodules$", "stb_image\\.h"],
"Exclude": ["^\\.gitmodules$", "stb_image\\.h", "oci-go/"],
"Disable": {
"IndentSize": true
}
3 changes: 3 additions & 0 deletions .gitignore
@@ -21,6 +21,9 @@
*.swp
*.tmp

# OCI Go generated files
oci-go/liboci.h

# IDE / OS

.cache/
70 changes: 69 additions & 1 deletion common/CMakeLists.txt
@@ -77,7 +77,75 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()

set(LLAMA_COMMON_EXTRA_LIBS build_info)
# Build OCI Go library
find_program(GO_EXECUTABLE go)
if (GO_EXECUTABLE)
# Check Go version - we need at least 1.21 for toolchain directive support
set(GO_VERSION_RESULT 1)
execute_process(
COMMAND ${GO_EXECUTABLE} version
OUTPUT_VARIABLE GO_VERSION_OUTPUT
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE GO_VERSION_RESULT
)

if (GO_VERSION_RESULT EQUAL 0)
# Extract version number from "go version go1.X.Y ..." or "go version go1.X ..."
string(REGEX MATCH "go([0-9]+)\\.([0-9]+)" GO_VERSION_MATCH "${GO_VERSION_OUTPUT}")
if (GO_VERSION_MATCH)
set(GO_VERSION_MAJOR ${CMAKE_MATCH_1})
set(GO_VERSION_MINOR ${CMAKE_MATCH_2})

if (GO_VERSION_MAJOR LESS 1 OR (GO_VERSION_MAJOR EQUAL 1 AND GO_VERSION_MINOR LESS 21))
message(WARNING "Go version ${GO_VERSION_MAJOR}.${GO_VERSION_MINOR} is too old. OCI functionality requires Go 1.21 or later. OCI functionality will not be available.")
set(GO_VERSION_OK FALSE)
else()
set(GO_VERSION_OK TRUE)
endif()
else()
message(WARNING "Unable to parse Go version from: ${GO_VERSION_OUTPUT}. OCI functionality will not be available.")
set(GO_VERSION_OK FALSE)
endif()
else()
message(WARNING "Failed to get Go version. OCI functionality will not be available.")
set(GO_VERSION_OK FALSE)
endif()
endif()

if (GO_EXECUTABLE AND GO_VERSION_OK)
set(OCI_GO_DIR ${CMAKE_SOURCE_DIR}/oci-go)
set(OCI_LIB ${OCI_GO_DIR}/liboci.a)
set(OCI_HEADER ${OCI_GO_DIR}/liboci.h)

add_custom_command(
OUTPUT ${OCI_LIB} ${OCI_HEADER}
COMMAND ${GO_EXECUTABLE} build -buildmode=c-archive -o ${OCI_LIB} ${OCI_GO_DIR}/oci.go
WORKING_DIRECTORY ${OCI_GO_DIR}
DEPENDS ${OCI_GO_DIR}/oci.go ${OCI_GO_DIR}/go.mod
COMMENT "Building OCI Go library"
)

add_custom_target(oci_go_lib DEPENDS ${OCI_LIB} ${OCI_HEADER})
add_dependencies(${TARGET} oci_go_lib)

target_include_directories(${TARGET} PRIVATE ${OCI_GO_DIR})
target_sources(${TARGET} PRIVATE oci.cpp oci.h)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_OCI)
set(LLAMA_COMMON_EXTRA_LIBS build_info ${OCI_LIB})

# On macOS, the Go runtime requires CoreFoundation and Security frameworks
if (APPLE)
find_library(OCI_CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
find_library(OCI_SECURITY_FRAMEWORK Security REQUIRED)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${OCI_CORE_FOUNDATION_FRAMEWORK} ${OCI_SECURITY_FRAMEWORK})
endif()
else()
if (NOT GO_EXECUTABLE)
message(WARNING "Go compiler not found. OCI functionality will not be available.")
endif()
set(LLAMA_COMMON_EXTRA_LIBS build_info)
endif()


# Use curl to download model url
if (LLAMA_CURL)
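The CMake rule above compiles oci-go/oci.go with `go build -buildmode=c-archive`, which only produces liboci.a and liboci.h if the Go file exports C-callable symbols through cgo. The actual oci.go is not part of this excerpt; the following is a minimal sketch of the shape such a file would need — `PullOCIModel` and `FreeString` exports matching the calls in common/oci.cpp — with the real registry-pull logic elided and the returned JSON hard-coded purely for illustration.

```go
// Hypothetical sketch of oci-go/oci.go (not the file from this PR).
// Exported functions become C symbols in liboci.h when built with
// `go build -buildmode=c-archive`.
package main

/*
#include <stdlib.h>
*/
import "C"

import "unsafe"

//export PullOCIModel
func PullOCIModel(imageRef *C.char, cacheDir *C.char) *C.char {
	ref := C.GoString(imageRef)
	dir := C.GoString(cacheDir)

	// The real implementation would resolve `ref` against the registry,
	// download the GGUF layer into `dir`, and report the result as JSON.
	_ = ref
	_ = dir

	// Placeholder payload in the shape that common/oci.cpp parses.
	return C.CString(`{"LocalPath":"","Digest":"","Error":null}`)
}

//export FreeString
func FreeString(s *C.char) {
	// Frees the C string allocated by C.CString in PullOCIModel.
	C.free(unsafe.Pointer(s))
}

// main is required for -buildmode=c-archive even though it is never run.
func main() {}
```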
128 changes: 29 additions & 99 deletions common/arg.cpp
@@ -5,6 +5,9 @@
#include "gguf.h" // for reading GGUF splits
#include "json-schema-to-grammar.h"
#include "log.h"
#ifdef LLAMA_USE_OCI
#include "oci.h"
#endif
#include "sampling.h"

// fix problem with std::min and std::max
@@ -1043,119 +1046,42 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
// Docker registry functions
//

static std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);

if (res.first != 200) {
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
}

std::string response_str(res.second.begin(), res.second.end());
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);

if (!response.contains("token")) {
throw std::runtime_error("Docker registry token response missing 'token' field");
}

return response["token"].get<std::string>();
}

#ifdef LLAMA_USE_OCI
static std::string common_docker_resolve_model(const std::string & docker) {
// Parse ai/smollm2:135M-Q4_0
size_t colon_pos = docker.find(':');
std::string repo, tag;
if (colon_pos != std::string::npos) {
repo = docker.substr(0, colon_pos);
tag = docker.substr(colon_pos + 1);
} else {
repo = docker;
tag = "latest";
}
// Parse image reference (e.g., ai/smollm2:135M-Q4_0)
std::string image_ref = docker;

// ai/ is the default
size_t slash_pos = docker.find('/');
// ai/ is the default namespace for Docker Hub
size_t slash_pos = docker.find('/');
if (slash_pos == std::string::npos) {
repo.insert(0, "ai/");
image_ref = "ai/" + docker;
}

LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
try {
// --- helper: digest validation ---
auto validate_oci_digest = [](const std::string & digest) -> std::string {
// Expected: algo:hex ; start with sha256 (64 hex chars)
// You can extend this map if supporting other algorithms in future.
static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
std::smatch m;
if (!std::regex_match(digest, m, re)) {
throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
}
// normalize hex to lowercase
std::string normalized = digest;
std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
return std::tolower(c);
});
return normalized;
};

std::string token = common_docker_get_token(repo); // Get authentication token

// Get manifest
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
std::string manifest_url = url_prefix + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
}

std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
std::string gguf_digest; // Find the GGUF layer
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}
// Add registry prefix if not present
if (image_ref.find("registry-1.docker.io/") != 0 && image_ref.find("docker.io/") != 0 &&
image_ref.find("index.docker.io/") != 0) {
// For Docker Hub images without explicit registry
image_ref = "index.docker.io/" + image_ref;
}

// Validate & normalize digest
gguf_digest = validate_oci_digest(gguf_digest);
LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
try {
// Get cache directory
std::string cache_dir = fs_get_cache_directory();

// Prepare local filename
std::string model_filename = repo;
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
model_filename += "_" + tag + ".gguf";
std::string local_path = fs_get_cache_file(model_filename);
// Call the Go OCI library
auto result = oci_pull_model(image_ref, cache_dir);

const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false)) {
throw std::runtime_error("Failed to download Docker Model");
if (!result.success()) {
throw std::runtime_error("OCI pull failed: " + result.error_message);
}

LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
return local_path;
return result.local_path;
} catch (const std::exception & e) {
LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what());
throw;
}
}
#endif // LLAMA_USE_OCI

//
// utils
@@ -1208,7 +1134,11 @@ static handle_model_result common_params_handle_model(
// handle pre-fill default model path and url based on hf_repo and hf_file
{
if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
#ifdef LLAMA_USE_OCI
model.path = common_docker_resolve_model(model.docker_repo);
#else
LOG_ERR("Need to build with go compiler and LLAMA_USE_OCI\n");
#endif
} else if (!model.hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (model.hf_file.empty()) {
58 changes: 58 additions & 0 deletions common/oci.cpp
@@ -0,0 +1,58 @@
#ifdef LLAMA_USE_OCI

# include "oci.h"

# include "log.h"

# include <nlohmann/json.hpp>

// Include the Go-generated header
# include "../oci-go/liboci.h"

using json = nlohmann::ordered_json;

oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir) {
oci_pull_result result;
result.error_code = 0;

// Call the Go function
char * json_result = PullOCIModel(const_cast<char *>(imageRef.c_str()), const_cast<char *>(cacheDir.c_str()));

if (json_result == nullptr) {
result.error_code = 1;
result.error_message = "Failed to call OCI pull function";
return result;
}

try {
// Parse the JSON result
std::string json_str(json_result);
auto j = json::parse(json_str);

if (j.contains("LocalPath")) {
result.local_path = j["LocalPath"].get<std::string>();
}
if (j.contains("Digest")) {
result.digest = j["Digest"].get<std::string>();
}
if (j.contains("Error") && !j["Error"].is_null()) {
auto err = j["Error"];
if (err.contains("Code")) {
result.error_code = err["Code"].get<int>();
}
if (err.contains("Message")) {
result.error_message = err["Message"].get<std::string>();
}
}
} catch (const std::exception & e) {
result.error_code = 1;
result.error_message = std::string("Failed to parse result: ") + e.what();
}

// Free the Go-allocated string
FreeString(json_result);

return result;
}

#endif // LLAMA_USE_OCI
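common/oci.cpp expects the Go side to hand back a JSON document with `LocalPath`, `Digest`, and an optional `Error` object carrying `Code` and `Message`. The Go implementation is not shown in this diff; below is a hedged sketch of how that payload could be produced with encoding/json. The field names are taken from the C++ parser above; the path and digest values are purely illustrative.

```go
// Hypothetical sketch of the result marshaling on the Go side (assumed, not from the PR).
package main

import (
	"encoding/json"
	"fmt"
)

// Field names mirror what common/oci.cpp reads: Error.Code and Error.Message.
type ociError struct {
	Code    int
	Message string
}

// LocalPath and Digest are read directly; a nil Error marshals to "Error": null,
// which the C++ parser treats as success.
type ociResult struct {
	LocalPath string
	Digest    string
	Error     *ociError
}

func marshalResult(r ociResult) string {
	out, err := json.Marshal(r)
	if err != nil {
		return `{"Error":{"Code":1,"Message":"failed to marshal result"}}`
	}
	return string(out)
}

func main() {
	ok := ociResult{LocalPath: "/cache/ai_smollm2_135M-Q4_0.gguf", Digest: "sha256:..."}
	fmt.Println(marshalResult(ok))
}
```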
22 changes: 22 additions & 0 deletions common/oci.h
@@ -0,0 +1,22 @@
#pragma once

#ifdef LLAMA_USE_OCI

#include <string>

// Structure to hold OCI pull results
struct oci_pull_result {
std::string local_path;
std::string digest;
int error_code;
std::string error_message;

bool success() const { return error_code == 0; }
};

// Pull a model from an OCI registry
// imageRef: full image reference (e.g., "ai/smollm2:135M-Q4_0", "registry.io/user/model:tag")
// cacheDir: directory to cache downloaded models
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir);

#endif // LLAMA_USE_OCI