Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
*.swp
*.tmp

# OCI Go generated files
oci-go/liboci.h

# IDE / OS

.cache/
Expand Down
28 changes: 27 additions & 1 deletion common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ add_library(${TARGET} STATIC
log.h
ngram-cache.cpp
ngram-cache.h
oci.cpp
oci.h
regex-partial.cpp
regex-partial.h
sampling.cpp
Expand All @@ -77,7 +79,31 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()

set(LLAMA_COMMON_EXTRA_LIBS build_info)
# Build the OCI Go library as a C archive and link it into ${TARGET}.
#
# The Go sources live in the `oci-go` directory next to this one (oci.cpp
# includes "../oci-go/liboci.h"). The path is anchored on
# CMAKE_CURRENT_SOURCE_DIR rather than CMAKE_SOURCE_DIR so that it still
# resolves when this project is pulled in through add_subdirectory() and
# CMAKE_SOURCE_DIR points at the parent project.
find_program(GO_EXECUTABLE go)
if (GO_EXECUTABLE)
    get_filename_component(OCI_GO_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../oci-go" ABSOLUTE)
    set(OCI_LIB    "${OCI_GO_DIR}/liboci.a")
    set(OCI_HEADER "${OCI_GO_DIR}/liboci.h")

    # `go build -buildmode=c-archive` emits the archive and a matching header
    # (same base name as the -o argument), so both are declared as outputs.
    add_custom_command(
        OUTPUT  "${OCI_LIB}" "${OCI_HEADER}"
        COMMAND "${GO_EXECUTABLE}" build -buildmode=c-archive -o "${OCI_LIB}" "${OCI_GO_DIR}/oci.go"
        WORKING_DIRECTORY "${OCI_GO_DIR}"
        DEPENDS "${OCI_GO_DIR}/oci.go" "${OCI_GO_DIR}/go.mod"
        COMMENT "Building OCI Go library"
        VERBATIM
    )

    # The custom target makes the generated files available before ${TARGET} compiles.
    add_custom_target(oci_go_lib DEPENDS "${OCI_LIB}" "${OCI_HEADER}")
    add_dependencies(${TARGET} oci_go_lib)

    target_include_directories(${TARGET} PRIVATE "${OCI_GO_DIR}")
    set(LLAMA_COMMON_EXTRA_LIBS build_info "${OCI_LIB}")
else()
    # NOTE(review): oci.cpp is listed unconditionally in add_library() above and
    # includes the generated liboci.h, so without Go the compile step is expected
    # to fail rather than merely lose OCI support - confirm and guard if needed.
    message(WARNING "Go compiler not found. OCI functionality will not be available.")
    set(LLAMA_COMMON_EXTRA_LIBS build_info)
endif()


# Use curl to download model url
if (LLAMA_CURL)
Expand Down
132 changes: 29 additions & 103 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "gguf.h" // for reading GGUF splits
#include "json-schema-to-grammar.h"
#include "log.h"
#include "oci.h"
#include "sampling.h"

// fix problem with std::min and std::max
Expand Down Expand Up @@ -1043,116 +1044,41 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_
// Docker registry functions
//

static std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);

if (res.first != 200) {
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
}

std::string response_str(res.second.begin(), res.second.end());
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);

if (!response.contains("token")) {
throw std::runtime_error("Docker registry token response missing 'token' field");
}

return response["token"].get<std::string>();
}

// Resolve a Docker/OCI model reference (e.g. "ai/smollm2:135M-Q4_0") to the path
// of a locally stored model file, logging and rethrowing on failure.
//
// NOTE(review): this span is a rendered diff without +/- markers. The removed
// implementation (token request, manifest fetch, GGUF layer lookup, blob download)
// and the added implementation (delegating to oci_pull_model) are interleaved
// line by line - `slash_pos` is declared twice, there are two `try {` openers and
// two LOG_ERR lines in the catch. Read it as two versions of the same function,
// not as compilable code.
static std::string common_docker_resolve_model(const std::string & docker) {
// Parse ai/smollm2:135M-Q4_0
size_t colon_pos = docker.find(':');
std::string repo, tag;
if (colon_pos != std::string::npos) {
repo = docker.substr(0, colon_pos);
tag = docker.substr(colon_pos + 1);
} else {
repo = docker;
tag = "latest";
}

// ai/ is the default
size_t slash_pos = docker.find('/');
// Parse image reference (e.g., ai/smollm2:135M-Q4_0)
std::string image_ref = docker;

// ai/ is the default namespace for Docker Hub
size_t slash_pos = docker.find('/');
if (slash_pos == std::string::npos) {
repo.insert(0, "ai/");
image_ref = "ai/" + docker;
}

LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
try {
// --- helper: digest validation ---
auto validate_oci_digest = [](const std::string & digest) -> std::string {
// Expected: algo:hex ; start with sha256 (64 hex chars)
// You can extend this map if supporting other algorithms in future.
static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
std::smatch m;
if (!std::regex_match(digest, m, re)) {
throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
}
// normalize hex to lowercase
// (offset 7 skips the "sha256:" prefix, which the regex above has already verified)
std::string normalized = digest;
std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
return std::tolower(c);
});
return normalized;
};

std::string token = common_docker_get_token(repo); // Get authentication token

// Get manifest
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
std::string manifest_url = url_prefix + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
}

std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
std::string gguf_digest; // Find the GGUF layer
// The first layer whose mediaType mentions "gguf" wins; later layers are ignored.
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}

// Validate & normalize digest
gguf_digest = validate_oci_digest(gguf_digest);
LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());

// Prepare local filename
std::string model_filename = repo;
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
model_filename += "_" + tag + ".gguf";
std::string local_path = fs_get_cache_file(model_filename);
// Add registry prefix if not present
if (image_ref.find("registry-1.docker.io/") != 0 &&
image_ref.find("docker.io/") != 0 &&
image_ref.find("index.docker.io/") != 0) {
// For Docker Hub images without explicit registry
image_ref = "index.docker.io/" + image_ref;
}

const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false)) {
throw std::runtime_error("Failed to download Docker Model");
}
LOG_INF("%s: Pulling OCI model: %s\n", __func__, image_ref.c_str());

LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
return local_path;
try {
// Get cache directory
std::string cache_dir = fs_get_cache_directory();

// Call the Go OCI library
auto result = oci_pull_model(image_ref, cache_dir);

// A failed pull is turned into an exception so it goes through the catch below.
if (!result.success()) {
throw std::runtime_error("OCI pull failed: " + result.error_message);
}

LOG_INF("%s: Downloaded OCI model to: %s\n", __func__, result.local_path.c_str());
return result.local_path;
} catch (const std::exception & e) {
// Log the failure, then rethrow the same exception to the caller.
LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what());
throw;
}
}
Expand Down
58 changes: 58 additions & 0 deletions common/oci.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "oci.h"
#include "log.h"

#include <nlohmann/json.hpp>

// Include the Go-generated header
extern "C" {
#include "../oci-go/liboci.h"
}

using json = nlohmann::ordered_json;

// Pull a model through the Go-generated PullOCIModel entry point and decode its
// JSON reply into an oci_pull_result.
//
// imageRef: image reference forwarded unchanged to PullOCIModel
// cacheDir: cache directory forwarded unchanged to PullOCIModel
//
// Returns a result whose success() is true only when the reply reports no error
// AND carries a non-empty "LocalPath". Every failure (null reply, malformed JSON,
// reported error, missing path) is signalled through error_code / error_message;
// exceptions derived from std::exception raised while decoding are caught here.
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir) {
    oci_pull_result result;
    result.error_code = 0;

    // The const_casts suggest the generated prototype takes non-const char *;
    // presumably the Go side does not modify the buffers - verify against liboci.h.
    char * json_result = PullOCIModel(
        const_cast<char*>(imageRef.c_str()),
        const_cast<char*>(cacheDir.c_str())
    );

    if (json_result == nullptr) {
        result.error_code    = 1;
        result.error_message = "Failed to call OCI pull function";
        return result;
    }

    try {
        const json j = json::parse(std::string(json_result));

        if (j.contains("LocalPath")) {
            result.local_path = j.at("LocalPath").get<std::string>();
        }
        if (j.contains("Digest")) {
            result.digest = j.at("Digest").get<std::string>();
        }
        if (j.contains("Error") && !j.at("Error").is_null()) {
            const json & err = j.at("Error"); // reference: no copy of the sub-object
            if (err.contains("Code")) {
                result.error_code = err.at("Code").get<int>();
            }
            if (err.contains("Message")) {
                result.error_message = err.at("Message").get<std::string>();
            }
            // An error that carries a message but no (or a zero) code must still
            // be reported as a failure, otherwise success() would return true.
            if (result.error_code == 0 && !result.error_message.empty()) {
                result.error_code = 1;
            }
        }

        // A reply without a path is unusable by callers even if no error was reported.
        if (result.error_code == 0 && result.local_path.empty()) {
            result.error_code    = 1;
            result.error_message = "OCI pull returned no local path";
        }
    } catch (const std::exception & e) {
        result.error_code    = 1;
        result.error_message = std::string("Failed to parse result: ") + e.what();
    }

    // Hand the reply buffer back to the library that produced it.
    FreeString(json_result);

    return result;
}
18 changes: 18 additions & 0 deletions common/oci.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <string>

// Outcome of an OCI model pull.
// A default-constructed result is in the "success" state (error_code == 0) with
// empty strings, so success() never reads an uninitialized value.
struct oci_pull_result {
    std::string local_path;      // where the pulled model was stored locally
    std::string digest;          // digest reported for the pulled content
    int         error_code = 0;  // 0 means success; any other value is a failure
    std::string error_message;   // human-readable description of the failure

    // True when no error has been recorded.
    bool success() const { return error_code == 0; }
};

// Pull a model from an OCI registry
// imageRef: full image reference (e.g., "ai/smollm2:135M-Q4_0", "registry.io/user/model:tag")
// cacheDir: directory to cache downloaded models
// returns:  an oci_pull_result; call success() before using local_path - when it is
//           false, error_code is non-zero and error_message describes the failure
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir);
Loading
Loading