diff --git a/.ecrc b/.ecrc index c68877ec211f1..17298e3097a72 100644 --- a/.ecrc +++ b/.ecrc @@ -1,5 +1,5 @@ { - "Exclude": ["^\\.gitmodules$", "stb_image\\.h"], + "Exclude": ["^\\.gitmodules$", "stb_image\\.h", "oci-go/"], "Disable": { "IndentSize": true } diff --git a/.gitignore b/.gitignore index c7d000978571a..9f98323597740 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,9 @@ *.swp *.tmp +# OCI Go generated files +oci-go/liboci.h + # IDE / OS .cache/ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index fe290bf8fdda4..46d9b25c49628 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -77,7 +77,75 @@ if (BUILD_SHARED_LIBS) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() -set(LLAMA_COMMON_EXTRA_LIBS build_info) +# Build OCI Go library +find_program(GO_EXECUTABLE go) +if (GO_EXECUTABLE) + # Check Go version - we need at least 1.21 for toolchain directive support + set(GO_VERSION_RESULT 1) + execute_process( + COMMAND ${GO_EXECUTABLE} version + OUTPUT_VARIABLE GO_VERSION_OUTPUT + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE GO_VERSION_RESULT + ) + + if (GO_VERSION_RESULT EQUAL 0) + # Extract version number from "go version go1.X.Y ..." or "go version go1.X ..." + string(REGEX MATCH "go([0-9]+)\\.([0-9]+)" GO_VERSION_MATCH "${GO_VERSION_OUTPUT}") + if (GO_VERSION_MATCH) + set(GO_VERSION_MAJOR ${CMAKE_MATCH_1}) + set(GO_VERSION_MINOR ${CMAKE_MATCH_2}) + + if (GO_VERSION_MAJOR LESS 1 OR (GO_VERSION_MAJOR EQUAL 1 AND GO_VERSION_MINOR LESS 21)) + message(WARNING "Go version ${GO_VERSION_MAJOR}.${GO_VERSION_MINOR} is too old. OCI functionality requires Go 1.21 or later. OCI functionality will not be available.") + set(GO_VERSION_OK FALSE) + else() + set(GO_VERSION_OK TRUE) + endif() + else() + message(WARNING "Unable to parse Go version from: ${GO_VERSION_OUTPUT}. OCI functionality will not be available.") + set(GO_VERSION_OK FALSE) + endif() + else() + message(WARNING "Failed to get Go version. 
OCI functionality will not be available.") + set(GO_VERSION_OK FALSE) + endif() +endif() + +if (GO_EXECUTABLE AND GO_VERSION_OK) + set(OCI_GO_DIR ${CMAKE_SOURCE_DIR}/oci-go) + set(OCI_LIB ${OCI_GO_DIR}/liboci.a) + set(OCI_HEADER ${OCI_GO_DIR}/liboci.h) + + add_custom_command( + OUTPUT ${OCI_LIB} ${OCI_HEADER} + COMMAND ${GO_EXECUTABLE} build -buildmode=c-archive -o ${OCI_LIB} ${OCI_GO_DIR}/oci.go + WORKING_DIRECTORY ${OCI_GO_DIR} + DEPENDS ${OCI_GO_DIR}/oci.go ${OCI_GO_DIR}/go.mod + COMMENT "Building OCI Go library" + ) + + add_custom_target(oci_go_lib DEPENDS ${OCI_LIB} ${OCI_HEADER}) + add_dependencies(${TARGET} oci_go_lib) + + target_include_directories(${TARGET} PRIVATE ${OCI_GO_DIR}) + target_sources(${TARGET} PRIVATE oci.cpp oci.h) + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_OCI) + set(LLAMA_COMMON_EXTRA_LIBS build_info ${OCI_LIB}) + + # On macOS, the Go runtime requires CoreFoundation and Security frameworks + if (APPLE) + find_library(OCI_CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED) + find_library(OCI_SECURITY_FRAMEWORK Security REQUIRED) + set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${OCI_CORE_FOUNDATION_FRAMEWORK} ${OCI_SECURITY_FRAMEWORK}) + endif() +else() + if (NOT GO_EXECUTABLE) + message(WARNING "Go compiler not found. 
#ifdef LLAMA_USE_OCI
// Resolve a Docker/OCI model reference to a local file path by pulling it
// through oci_pull_model().
//
// docker: the reference as given by the user, e.g. "smollm2",
//         "ai/smollm2:135M-Q4_0" or "registry.example.com/namespace/model:tag".
//
// Normalization applied before the pull:
//   1. a reference without any '/' gets the default "ai/" namespace;
//   2. a reference whose first path component does not name a registry host is
//      qualified with the Docker Hub registry "index.docker.io/".
//
// Returns the local path reported by oci_pull_model().
// Throws std::runtime_error when the pull reports a failure; the error is
// logged here and then rethrown to the caller.
static std::string common_docker_resolve_model(const std::string & docker) {
    // Parse image reference (e.g., ai/smollm2:135M-Q4_0)
    std::string image_ref = docker;

    // ai/ is the default namespace for Docker Hub
    if (docker.find('/') == std::string::npos) {
        image_ref = "ai/" + docker;
    }

    // Only fall back to Docker Hub when the reference does not already carry a
    // registry host. The first path component is treated as a host when it
    // contains '.' (domain) or ':' (port), or is exactly "localhost". Checking
    // only for the docker.io host names would rewrite custom registries such as
    // "registry.example.com/ns/model:tag" into a Docker Hub repository path.
    const std::string first_component = image_ref.substr(0, image_ref.find('/'));
    const bool has_registry_host =
        first_component.find_first_of(".:") != std::string::npos || first_component == "localhost";
    if (!has_registry_host) {
        // For Docker Hub images without explicit registry
        image_ref = "index.docker.io/" + image_ref;
    }

    try {
        // Get cache directory
        const std::string cache_dir = fs_get_cache_directory();

        // Call the Go OCI library
        const auto result = oci_pull_model(image_ref, cache_dir);

        if (!result.success()) {
            throw std::runtime_error("OCI pull failed: " + result.error_message);
        }

        return result.local_path;
    } catch (const std::exception & e) {
        LOG_ERR("%s: OCI model download failed: %s\n", __func__, e.what());
        throw;
    }
}
#endif // LLAMA_USE_OCI
#ifdef LLAMA_USE_OCI

#include "oci.h"

#include "log.h"

#include <nlohmann/json.hpp>

#include <memory>
#include <string>

// Include the Go-generated header. It already wraps its exported declarations
// in its own `extern "C"` block, so it is included directly: wrapping the whole
// header again would also put the standard headers it includes (for C++ under
// MSVC that is <complex>) inside a C-linkage block.
#include "../oci-go/liboci.h"

using json = nlohmann::ordered_json;

// Pull a model through the Go OCI library and translate its JSON reply.
//
// imageRef: image reference handed to PullOCIModel()
// cacheDir: directory handed to PullOCIModel() as the download cache
//
// Returns an oci_pull_result. On success error_code is 0 and local_path/digest
// hold the values reported by the Go side. error_code is non-zero and
// error_message is set when:
//   - the Go call returns a null pointer,
//   - the reply is not valid JSON or has unexpected field types,
//   - the reply carries an "Error" object, or
//   - the reply reports no error but also no local path.
// This function does not throw for those cases; they are reported in the result.
oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir) {
    oci_pull_result result;
    result.error_code = 0;

    // The exported Go signature takes non-const char *, hence the const_cast;
    // the Go side copies both arguments with C.GoString.
    // The returned buffer is C-allocated by Go and must be released with
    // FreeString(); the unique_ptr guarantees that on every exit path.
    std::unique_ptr<char, decltype(&FreeString)> json_result(
        PullOCIModel(const_cast<char *>(imageRef.c_str()), const_cast<char *>(cacheDir.c_str())),
        &FreeString);

    if (!json_result) {
        result.error_code    = 1;
        result.error_message = "Failed to call OCI pull function";
        return result;
    }

    try {
        // Parse the JSON result
        const json j = json::parse(json_result.get());

        if (j.contains("LocalPath")) {
            result.local_path = j.at("LocalPath").get<std::string>();
        }
        if (j.contains("Digest")) {
            result.digest = j.at("Digest").get<std::string>();
        }
        if (j.contains("Error") && !j.at("Error").is_null()) {
            const json & err = j.at("Error");
            // An Error object always means failure, even if "Code" is missing
            // or zero; never let it pass as success.
            const int code       = err.value("Code", 1);
            result.error_code    = code != 0 ? code : 1;
            result.error_message = err.value("Message", std::string("unknown OCI error"));
        }
    } catch (const std::exception & e) {
        result.error_code    = 1;
        result.error_message = std::string("Failed to parse result: ") + e.what();
    }

    // A "successful" reply without a path is unusable for the caller.
    if (result.success() && result.local_path.empty()) {
        result.error_code    = 1;
        result.error_message = "OCI pull returned no local path";
    }

    return result;
}

#endif // LLAMA_USE_OCI
+#pragma once + +#ifdef LLAMA_USE_OCI + +#include + +// Structure to hold OCI pull results +struct oci_pull_result { + std::string local_path; + std::string digest; + int error_code; + std::string error_message; + + bool success() const { return error_code == 0; } +}; + +// Pull a model from an OCI registry +// imageRef: full image reference (e.g., "ai/smollm2:135M-Q4_0", "registry.io/user/model:tag") +// cacheDir: directory to cache downloaded models +oci_pull_result oci_pull_model(const std::string & imageRef, const std::string & cacheDir); + +#endif // LLAMA_USE_OCI diff --git a/docs/oci-registry.md b/docs/oci-registry.md new file mode 100644 index 0000000000000..9563828ae5330 --- /dev/null +++ b/docs/oci-registry.md @@ -0,0 +1,161 @@ +# OCI/Docker Registry Integration + +llama.cpp supports pulling models directly from OCI-compliant registries such as Docker Hub. This feature uses the [go-containerregistry](https://github.com/google/go-containerregistry) library to handle registry authentication and image pulling. + +## Features + +- Pull GGUF models from Docker Hub and other OCI registries +- Automatic authentication using Docker credentials (via `docker login`) +- Support for private registries with authentication +- Caching of downloaded models +- **Docker-style progress bars** showing download progress, speed, and ETA +- **Resumable downloads** - interrupted downloads can be resumed automatically + +## Prerequisites + +- Go 1.24 or later (for building from source) +- Docker credentials configured (for private registries) + +## Usage + +### Pulling Public Models + +To pull a public model from Docker Hub: + +```bash +./llama-cli --docker-repo ai/smollm2:135M-Q4_0 +``` + +By default, models are pulled from the `ai/` namespace on Docker Hub. 
If no namespace is specified, `ai/` is assumed: + +```bash +# These are equivalent: +./llama-cli --docker-repo gemma3 +./llama-cli --docker-repo ai/gemma3 +``` + +### Pulling Private Models + +For private models or registries requiring authentication, first authenticate using Docker: + +```bash +docker login +# Or for a specific registry: +docker login registry.example.com +``` + +Then pull the model: + +```bash +./llama-cli --docker-repo myuser/private-model:Q4_K_M +``` + +### Custom Registries + +You can also pull from custom OCI registries by specifying the full registry URL: + +```bash +./llama-cli --docker-repo registry.example.com/namespace/model:tag +``` + +## How It Works + +1. The `--docker-repo` (or `-dr`) flag specifies the OCI image reference +2. llama.cpp uses the Go-based OCI library to: + - Parse the image reference + - Authenticate using Docker credentials (if available) + - Fetch the manifest from the registry + - Identify and download the GGUF layer with progress tracking + - Display docker-style progress bars during download +3. The model is cached locally for future use +4. 
If a download is interrupted, it will automatically resume from where it left off + +### Progress Display + +During download, you'll see progress information similar to Docker: + +``` +1234567890ab: Downloading [===================> ] 39.0% (39.00 MB / 100.00 MB) 20.10 MB/s +``` + +- **1234567890ab**: Short digest of the layer being downloaded +- **Progress bar**: Visual representation of download progress +- **Percentage**: Completion percentage +- **Size**: Downloaded size / Total size in MB +- **Speed**: Current download speed in MB/s + +### Resumable Downloads + +If a download is interrupted (e.g., network failure, Ctrl+C), the next download attempt will automatically resume: + +``` +1234567890ab: Resuming download from 39.00 MB +1234567890ab: Downloading [===================> ] 41.0% (41.00 MB / 100.00 MB) 18.50 MB/s +``` + +The download will continue from where it stopped, saving time and bandwidth. The integrity of resumed downloads is verified using layer digests. + +## Image Format + +Models must be packaged as OCI images with a GGUF layer. The layer should have one of these media types: +- `application/vnd.docker.ai.gguf.v3` +- Any media type containing "gguf" + +## Authentication + +Authentication is handled automatically using the same credentials as the Docker CLI: +- Credentials are stored in `~/.docker/config.json` +- Use `docker login` to authenticate +- Supports credential helpers and authentication providers + +## Caching + +Downloaded models are cached in the standard llama.cpp cache directory: +- Linux/macOS: `~/.cache/llama.cpp/` +- Windows: `%LOCALAPPDATA%\llama.cpp\` + +Cached models are verified using their digest to ensure integrity. If the cached file matches the expected digest, it will be used instead of re-downloading. + +### Partial Downloads + +Partial downloads are stored with a `.tmp` extension alongside a `.digest` file for verification. If a download is interrupted: +1. The partial file and digest are preserved +2. 
On the next attempt, if the digest matches, download resumes +3. If the digest differs (e.g., model was updated), a fresh download starts + +## Building with OCI Support + +OCI support is automatically enabled if Go is available during build: + +```bash +cmake -B build +cmake --build build +``` + +If Go is not found or is too old, a warning will be displayed and OCI functionality will be unavailable. + +## Troubleshooting + +### Authentication Issues + +If you encounter authentication errors: +1. Ensure you're logged in: `docker login` +2. Verify credentials: Check `~/.docker/config.json` +3. For private registries, specify the full registry URL + +### Network Issues + +If downloads fail or are interrupted: +1. Check your internet connection +2. Verify the registry is accessible +3. Try pulling a test image with Docker: `docker pull <image>` +4. The download will automatically resume on retry if the partial download is valid + +Note: Progress bars require a TTY. If running in a non-interactive environment (e.g., CI/CD), progress information will be minimal. + +### Build Issues + +If OCI support is not available: +1. Ensure Go 1.24 or later is installed: `go version` +2. Rebuild the project: `cmake --build build --clean-first` +3. 
Check CMake output for Go-related warnings diff --git a/oci-go/go.mod b/oci-go/go.mod new file mode 100644 index 0000000000000..2cbe3479dc317 --- /dev/null +++ b/oci-go/go.mod @@ -0,0 +1,26 @@ +module github.com/ggml-org/llama.cpp/oci-go + +go 1.24.0 + +toolchain go1.24.9 + +require ( + github.com/google/go-containerregistry v0.20.6 + golang.org/x/term v0.36.0 +) + +require ( + github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect + github.com/docker/cli v28.2.2+incompatible // indirect + github.com/docker/distribution v2.8.3+incompatible // indirect + github.com/docker/docker-credential-helpers v0.9.3 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect + github.com/vbatts/tar-split v0.12.1 // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/sys v0.37.0 // indirect +) diff --git a/oci-go/go.sum b/oci-go/go.sum new file mode 100644 index 0000000000000..181419508c91b --- /dev/null +++ b/oci-go/go.sum @@ -0,0 +1,48 @@ +github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= +github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/cli v28.2.2+incompatible h1:qzx5BNUDFqlvyq4AHzdNB7gSyVTmU4cgsyN9SdInc1A= +github.com/docker/cli v28.2.2+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/distribution v2.8.3+incompatible 
h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= +github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8= +github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-containerregistry v0.20.6 h1:cvWX87UxxLgaH76b4hIvya6Dzz9qHB31qAwjAohdSTU= +github.com/google/go-containerregistry v0.20.6/go.mod h1:T0x8MuoAoKX/873bkeSfLD2FAkwCDf9/HZgsFJ02E2Y= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= +github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= 
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= +github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= +gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= diff --git a/oci-go/liboci.h b/oci-go/liboci.h new file mode 100644 index 0000000000000..7eceed9711240 --- /dev/null +++ b/oci-go/liboci.h @@ -0,0 +1,95 @@ +/* Code generated by cmd/cgo; DO NOT EDIT. 
*/ + +/* package command-line-arguments */ + + +#line 1 "cgo-builtin-export-prolog" + +#include + +#ifndef GO_CGO_EXPORT_PROLOGUE_H +#define GO_CGO_EXPORT_PROLOGUE_H + +#ifndef GO_CGO_GOSTRING_TYPEDEF +typedef struct { const char *p; ptrdiff_t n; } _GoString_; +extern size_t _GoStringLen(_GoString_ s); +extern const char *_GoStringPtr(_GoString_ s); +#endif + +#endif + +/* Start of preamble from import "C" comments. */ + + +#line 3 "oci.go" + +#include + +#line 1 "cgo-generated-wrapper" + + +/* End of preamble from import "C" comments. */ + + +/* Start of boilerplate cgo prologue. */ +#line 1 "cgo-gcc-export-header-prolog" + +#ifndef GO_CGO_PROLOGUE_H +#define GO_CGO_PROLOGUE_H + +typedef signed char GoInt8; +typedef unsigned char GoUint8; +typedef short GoInt16; +typedef unsigned short GoUint16; +typedef int GoInt32; +typedef unsigned int GoUint32; +typedef long long GoInt64; +typedef unsigned long long GoUint64; +typedef GoInt64 GoInt; +typedef GoUint64 GoUint; +typedef size_t GoUintptr; +typedef float GoFloat32; +typedef double GoFloat64; +#ifdef _MSC_VER +#if !defined(__cplusplus) || _MSVC_LANG <= 201402L +#include +typedef _Fcomplex GoComplex64; +typedef _Dcomplex GoComplex128; +#else +#include +typedef std::complex GoComplex64; +typedef std::complex GoComplex128; +#endif +#else +typedef float _Complex GoComplex64; +typedef double _Complex GoComplex128; +#endif + +/* + static assertion to make sure the file is being used on architecture + at least with matching size of GoInt. +*/ +typedef char _check_for_64_bit_pointer_matching_GoInt[sizeof(void*)==64/8 ? 1:-1]; + +#ifndef GO_CGO_GOSTRING_TYPEDEF +typedef _GoString_ GoString; +#endif +typedef void *GoMap; +typedef void *GoChan; +typedef struct { void *t; void *v; } GoInterface; +typedef struct { void *data; GoInt len; GoInt cap; } GoSlice; + +#endif + +/* End of boilerplate cgo prologue. 
*/ + +#ifdef __cplusplus +extern "C" { +#endif + +extern char* PullOCIModel(char* imageRef, char* cacheDir); +extern void FreeString(char* s); + +#ifdef __cplusplus +} +#endif diff --git a/oci-go/oci.go b/oci-go/oci.go new file mode 100644 index 0000000000000..83b4b9c9a4061 --- /dev/null +++ b/oci-go/oci.go @@ -0,0 +1,368 @@ +package main + +/* +#include +*/ +import "C" +import ( + "context" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync/atomic" + "time" + "unsafe" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/google/go-containerregistry/pkg/name" + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/remote" + "golang.org/x/term" +) + +// OCIError represents an error that occurred during OCI operations +type OCIError struct { + Code int + Message string +} + +// OCIResult represents the result of pulling a model +type OCIResult struct { + LocalPath string + Digest string + Error *OCIError +} + +//export PullOCIModel +func PullOCIModel(imageRef, cacheDir *C.char) *C.char { + goImageRef := C.GoString(imageRef) + goCacheDir := C.GoString(cacheDir) + + result, err := pullModel(goImageRef, goCacheDir) + if err != nil { + result = &OCIResult{ + Error: &OCIError{ + Code: 1, + Message: err.Error(), + }, + } + } + + jsonBytes, _ := json.Marshal(result) + return C.CString(string(jsonBytes)) +} + +//export FreeString +func FreeString(s *C.char) { + C.free(unsafe.Pointer(s)) +} + +// progressWriter wraps an io.Writer and tracks progress with docker-style output +type progressWriter struct { + writer io.Writer + total int64 + downloaded int64 + lastPrint time.Time + layerDigest string + startTime time.Time +} + +func newProgressWriter(w io.Writer, total int64, digest string) *progressWriter { + return &progressWriter{ + writer: w, + total: total, + downloaded: 0, + lastPrint: time.Now(), + layerDigest: digest, + startTime: time.Now(), + } +} + +func (pw *progressWriter) Write(p 
[]byte) (int, error) { + n, err := pw.writer.Write(p) + if n > 0 { + atomic.AddInt64(&pw.downloaded, int64(n)) + + // Update progress display every 100ms + now := time.Now() + if now.Sub(pw.lastPrint) >= 100*time.Millisecond { + pw.printProgress() + pw.lastPrint = now + } + } + return n, err +} + +func (pw *progressWriter) printProgress() { + downloaded := atomic.LoadInt64(&pw.downloaded) + + // Calculate percentage and download speed + var percentage float64 + if pw.total > 0 { + percentage = float64(downloaded) / float64(pw.total) * 100.0 + } + + elapsed := time.Since(pw.startTime).Seconds() + speed := float64(0) + if elapsed > 0 { + speed = float64(downloaded) / elapsed / (1024 * 1024) // MB/s + } + + // Format sizes + downloadedMB := float64(downloaded) / (1024 * 1024) + totalMB := float64(pw.total) / (1024 * 1024) + + // Get short digest (first 12 chars after sha256:) + shortDigest := pw.layerDigest + if strings.HasPrefix(shortDigest, "sha256:") { + shortDigest = shortDigest[7:19] + } + + // Get terminal width, default to 80 if cannot be determined + termWidth := 80 + if width, _, err := term.GetSize(int(os.Stderr.Fd())); err == nil && width > 0 { + termWidth = width + } + + // Print docker-style progress + if pw.total > 0 { + // Build the progress message to measure its length + // Format: "shortDigest: Downloading [] 100.0% (9999.99 MB / 9999.99 MB) 999.99 MB/s" + prefix := fmt.Sprintf("%s: Downloading [", shortDigest) + suffix := fmt.Sprintf("] %.1f%% (%.2f MB / %.2f MB) %.2f MB/s", + percentage, downloadedMB, totalMB, speed) + + // Calculate available space for progress bar + // Reserve at least 10 chars for the bar, use remaining space up to 50 chars + fixedWidth := len(prefix) + len(suffix) + maxBarWidth := termWidth - fixedWidth + if maxBarWidth < 10 { + maxBarWidth = 10 + } else if maxBarWidth > 50 { + maxBarWidth = 50 + } + + // Build the complete line + var line strings.Builder + line.WriteString("\r") + line.WriteString(prefix) + + // Progress bar + 
filled := int(float64(maxBarWidth) * percentage / 100.0) + for i := 0; i < maxBarWidth; i++ { + if i < filled { + line.WriteString("=") + } else if i == filled { + line.WriteString(">") + } else { + line.WriteString(" ") + } + } + + line.WriteString(suffix) + + // Pad with spaces to clear any trailing characters from previous output + currentLen := len(prefix) + maxBarWidth + len(suffix) + if currentLen < termWidth { + padding := termWidth - currentLen + for i := 0; i < padding; i++ { + line.WriteString(" ") + } + } + + // Write the complete line + fmt.Fprint(os.Stderr, line.String()) + } else { + // Build line for unknown total size + line := fmt.Sprintf("\r%s: Downloading %.2f MB %.2f MB/s", + shortDigest, downloadedMB, speed) + + // Pad with spaces to clear trailing characters + if len(line) < termWidth { + padding := termWidth - len(line) + for i := 0; i < padding; i++ { + line += " " + } + } + + fmt.Fprint(os.Stderr, line) + } +} + +func (pw *progressWriter) finish() { + downloaded := atomic.LoadInt64(&pw.downloaded) + downloadedMB := float64(downloaded) / (1024 * 1024) + + // Get short digest + shortDigest := pw.layerDigest + if strings.HasPrefix(shortDigest, "sha256:") { + shortDigest = shortDigest[7:19] + } + + fmt.Fprintf(os.Stderr, "\r%s: Download complete (%.2f MB)\n", shortDigest, downloadedMB) +} + +func pullModel(imageRef, cacheDir string) (*OCIResult, error) { + ctx := context.Background() + + // Parse the image reference + ref, err := name.ParseReference(imageRef) + if err != nil { + return nil, fmt.Errorf("failed to parse image reference: %w", err) + } + + // Use docker config for authentication (supports docker login) + authenticator := authn.NewMultiKeychain( + authn.DefaultKeychain, + ) + + // Get the image descriptor + img, err := remote.Image(ref, remote.WithAuthFromKeychain(authenticator), remote.WithContext(ctx)) + if err != nil { + return nil, fmt.Errorf("failed to fetch image: %w", err) + } + + // Get the manifest + manifest, err := 
img.Manifest() + if err != nil { + return nil, fmt.Errorf("failed to get manifest: %w", err) + } + + // Find the GGUF layer + var ggufLayer v1.Layer + var ggufDigest string + var layerSize int64 + for _, layer := range manifest.Layers { + mediaType := string(layer.MediaType) + if mediaType == "application/vnd.docker.ai.gguf.v3" || strings.Contains(mediaType, "gguf") { + ggufLayer, err = img.LayerByDigest(layer.Digest) + if err != nil { + return nil, fmt.Errorf("failed to get GGUF layer: %w", err) + } + ggufDigest = layer.Digest.String() + layerSize = layer.Size + break + } + } + + if ggufLayer == nil { + return nil, fmt.Errorf("no GGUF layer found in image") + } + + // Prepare local file path + refStr := ref.String() + modelFilename := strings.ReplaceAll(refStr, "/", "_") + modelFilename = strings.ReplaceAll(modelFilename, ":", "_") + modelFilename += ".gguf" + + localPath := filepath.Join(cacheDir, modelFilename) + tempPath := localPath + ".tmp" + digestPath := localPath + ".digest" + + // Check if file already exists and is complete + if _, err := os.Stat(localPath); err == nil { + // File exists, verify digest matches + if storedDigest, err := os.ReadFile(digestPath); err == nil && string(storedDigest) == ggufDigest { + fmt.Fprintf(os.Stderr, "%s: Using cached model (digest verified)\n", ggufDigest[7:19]) + return &OCIResult{ + LocalPath: localPath, + Digest: ggufDigest, + }, nil + } + // Digest mismatch or missing, need to re-download + fmt.Fprintf(os.Stderr, "%s: Cache digest mismatch, re-downloading\n", ggufDigest[7:19]) + os.Remove(localPath) + os.Remove(digestPath) + os.Remove(tempPath) + } + + // Check for partial download + var existingSize int64 = 0 + var resuming bool = false + if fileInfo, err := os.Stat(tempPath); err == nil { + // Verify the digest matches what we expect + if storedDigest, err := os.ReadFile(digestPath); err == nil && string(storedDigest) == ggufDigest { + existingSize = fileInfo.Size() + if existingSize > 0 && existingSize < 
layerSize { + resuming = true + } + } else { + // Digest mismatch, remove partial file + os.Remove(tempPath) + os.Remove(digestPath) + } + } + + // Store digest for verification + if err := os.WriteFile(digestPath, []byte(ggufDigest), 0644); err != nil { + return nil, fmt.Errorf("failed to write digest file: %w", err) + } + + // Download the layer + layerReader, err := ggufLayer.Uncompressed() + if err != nil { + return nil, fmt.Errorf("failed to get layer reader: %w", err) + } + defer layerReader.Close() + + // Skip already downloaded bytes if resuming + if resuming && existingSize > 0 { + _, err = io.CopyN(io.Discard, layerReader, existingSize) + if err != nil { + return nil, fmt.Errorf("failed to skip downloaded bytes: %w", err) + } + } + + // Open file for appending or create new + var outFile *os.File + if resuming { + outFile, err = os.OpenFile(tempPath, os.O_APPEND|os.O_WRONLY, 0644) + } else { + outFile, err = os.Create(tempPath) + } + if err != nil { + return nil, fmt.Errorf("failed to create output file: %w", err) + } + + // Create progress writer + pw := newProgressWriter(outFile, layerSize, ggufDigest) + if resuming { + atomic.StoreInt64(&pw.downloaded, existingSize) + pw.startTime = time.Now() // Reset start time for accurate speed calculation + } + + // Copy the data with progress tracking + _, err = io.Copy(pw, layerReader) + outFile.Close() + + if err != nil { + // Don't remove partial file on error - allow resume + return nil, fmt.Errorf("failed to write layer data: %w", err) + } + + // Print completion message + pw.finish() + + // Verify downloaded file size matches expected + if fileInfo, err := os.Stat(tempPath); err == nil { + if fileInfo.Size() != layerSize { + return nil, fmt.Errorf("downloaded file size (%d) doesn't match expected size (%d)", + fileInfo.Size(), layerSize) + } + } + + // Rename to final location (atomic operation) + if err := os.Rename(tempPath, localPath); err != nil { + return nil, fmt.Errorf("failed to rename file: %w", 
err) + } + + return &OCIResult{ + LocalPath: localPath, + Digest: ggufDigest, + }, nil +} + +func main() {}