Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 128 additions & 2 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"fmt"
"io"
"net/http"
"os"
"slices"
"strings"

"github.com/docker/model-runner/pkg/distribution/huggingface"
"github.com/docker/model-runner/pkg/distribution/internal/progress"
"github.com/docker/model-runner/pkg/distribution/internal/store"
"github.com/docker/model-runner/pkg/distribution/registry"
Expand Down Expand Up @@ -162,9 +164,10 @@ func (c *Client) normalizeModelName(model string) string {
return model
}

// Normalize HuggingFace model names (lowercase path)
// Normalize HuggingFace model names
if strings.HasPrefix(model, "hf.co/") {
// Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect.
// Lowercase for OCI compatibility (repository names must be lowercase)
model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co"))
}

Expand Down Expand Up @@ -261,21 +264,31 @@ func (c *Client) resolveID(id string) string {

// PullModel pulls a model from a registry and returns the local file path
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error {
// Store original reference before normalization (needed for case-sensitive HuggingFace API)
originalReference := reference
// Normalize the model reference
reference = c.normalizeModelName(reference)
c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))

// Use the client's registry, or create a temporary one if bearer token is provided
registryClient := c.registry
var token string
if len(bearerToken) > 0 && bearerToken[0] != "" {
token = bearerToken[0]
// Create a temporary registry client with bearer token authentication
auth := &authn.Bearer{Token: bearerToken[0]}
auth := &authn.Bearer{Token: token}
registryClient = registry.FromClient(c.registry, registry.WithAuth(auth))
}

// First, fetch the remote model to get the manifest
remoteModel, err := registryClient.Model(ctx, reference)
if err != nil {
// Check if this is a HuggingFace reference and the error indicates no OCI manifest
if isHuggingFaceReference(reference) && isNotOCIError(err) {
c.log.Infoln("No OCI manifest found, attempting native HuggingFace pull")
// Pass original reference to preserve case-sensitivity for HuggingFace API
return c.pullNativeHuggingFace(ctx, originalReference, progressWriter, token)
}
return fmt.Errorf("reading model from registry: %w", err)
}

Expand Down Expand Up @@ -637,3 +650,116 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string,

return nil
}

// isHuggingFaceReference checks if a reference is a HuggingFace model reference
func isHuggingFaceReference(reference string) bool {
return strings.HasPrefix(reference, "huggingface.co/")
}

// isNotOCIError checks if the error indicates the model is not OCI-formatted
// This happens when the HuggingFace repository doesn't have an OCI manifest
func isNotOCIError(err error) bool {
if err == nil {
return false
}

// Check for registry errors indicating no manifest
var regErr *registry.Error
if errors.As(err, &regErr) {
if regErr.Code == "MANIFEST_UNKNOWN" || regErr.Code == "NAME_UNKNOWN" {
return true
}
}

// Note: We intentionally don't treat ErrInvalidReference as "not OCI" - that's a format error
// that should be reported to the user, not interpreted as a native HF model.
// The model name is lowercased during normalization to ensure OCI compatibility.

// Also check error message for common patterns
errStr := err.Error()
return strings.Contains(errStr, "MANIFEST_UNKNOWN") ||
strings.Contains(errStr, "NAME_UNKNOWN") ||
strings.Contains(errStr, "manifest unknown") ||
// HuggingFace returns this error for non-GGUF repositories
strings.Contains(errStr, "Repository is not GGUF") ||
strings.Contains(errStr, "not compatible with llama.cpp")
}

// parseHFReference extracts repo and revision from a HF reference
// e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision")
// e.g., "hf.co/org/model:latest" -> ("org/model", "main")
// Note: This preserves the original case of the repo name for HuggingFace API compatibility
func parseHFReference(reference string) (repo, revision string) {
// Remove registry prefix (handle both hf.co and huggingface.co)
ref := strings.TrimPrefix(reference, "huggingface.co/")
ref = strings.TrimPrefix(ref, "hf.co/")

// Split by colon to get tag
parts := strings.SplitN(ref, ":", 2)
repo = parts[0]

revision = "main"
if len(parts) == 2 && parts[1] != "" && parts[1] != "latest" {
revision = parts[1]
}

return repo, revision
}

// pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format)
// This is used when the model is stored as raw files (safetensors) on HuggingFace Hub
func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error {
repo, revision := parseHFReference(reference)
c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s", utils.SanitizeForLog(repo), utils.SanitizeForLog(revision))

// Create HuggingFace client
hfOpts := []huggingface.ClientOption{
huggingface.WithUserAgent(registry.DefaultUserAgent),
}
if token != "" {
hfOpts = append(hfOpts, huggingface.WithToken(token))
}
hfClient := huggingface.NewClient(hfOpts...)

// Create temp directory for downloads
tempDir, err := os.MkdirTemp("", "hf-model-*")
if err != nil {
return fmt.Errorf("create temp dir: %w", err)
}
defer os.RemoveAll(tempDir)

// Build model from HuggingFace repository
model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tempDir, progressWriter)
if err != nil {
// Convert HuggingFace errors to registry errors for consistent handling
var authErr *huggingface.AuthError
var notFoundErr *huggingface.NotFoundError
if errors.As(err, &authErr) {
return registry.ErrUnauthorized
}
if errors.As(err, &notFoundErr) {
return registry.ErrModelNotFound
}
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
c.log.Warnf("Failed to write error message: %v", writeErr)
}
return fmt.Errorf("build model from HuggingFace: %w", err)
}

// Write model to store
// Lowercase the reference for storage since OCI tags don't allow uppercase
storageTag := strings.ToLower(reference)
c.log.Infof("Writing model to store with tag: %s", utils.SanitizeForLog(storageTag))
if err := c.store.Write(model, []string{storageTag}, progressWriter); err != nil {
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
c.log.Warnf("Failed to write error message: %v", writeErr)
}
return fmt.Errorf("writing model to store: %w", err)
}

if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil {
c.log.Warnf("Failed to write success message: %v", err)
}

return nil
}
118 changes: 115 additions & 3 deletions pkg/distribution/distribution/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package distribution

import (
"context"
"errors"
"io"
"path/filepath"
"strings"
"testing"

"github.com/docker/model-runner/pkg/distribution/builder"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -66,7 +68,7 @@ func TestNormalizeModelName(t *testing.T) {
expected: "registry.example.com/myorg/model:v1",
},

// HuggingFace cases
// HuggingFace cases (lowercased for OCI reference compatibility)
{
name: "huggingface short form lowercase",
input: "hf.co/model",
Expand All @@ -75,12 +77,12 @@ func TestNormalizeModelName(t *testing.T) {
{
name: "huggingface short form uppercase",
input: "hf.co/Model",
expected: "huggingface.co/model:latest",
expected: "huggingface.co/model:latest", // lowercased for OCI compatibility
},
{
name: "huggingface short form with org",
input: "hf.co/MyOrg/MyModel",
expected: "huggingface.co/myorg/mymodel:latest",
expected: "huggingface.co/myorg/mymodel:latest", // lowercased for OCI compatibility
},
{
name: "huggingface with tag",
Expand Down Expand Up @@ -355,6 +357,116 @@ func createTestClient(t *testing.T) (*Client, func()) {
return client, cleanup
}

func TestIsHuggingFaceReference(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"huggingface.co prefix", "huggingface.co/org/model:latest", true},
{"huggingface.co without tag", "huggingface.co/org/model", true},
{"not huggingface", "registry.example.com/model:latest", false},
{"docker hub", "ai/gemma3:latest", false},
{"hf.co prefix (not normalized)", "hf.co/org/model", false}, // This is the un-normalized form
{"empty", "", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isHuggingFaceReference(tt.input)
if result != tt.expected {
t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}

func TestParseHFReference(t *testing.T) {
tests := []struct {
name string
input string
expectedRepo string
expectedRev string
}{
{
name: "basic with latest tag",
input: "huggingface.co/org/model:latest",
expectedRepo: "org/model",
expectedRev: "main", // latest maps to main
},
{
name: "with explicit revision",
input: "huggingface.co/org/model:v1.0",
expectedRepo: "org/model",
expectedRev: "v1.0",
},
{
name: "without tag",
input: "huggingface.co/org/model",
expectedRepo: "org/model",
expectedRev: "main",
},
{
name: "with commit hash as tag",
input: "huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct:abc123",
expectedRepo: "HuggingFaceTB/SmolLM2-135M-Instruct",
expectedRev: "abc123",
},
{
name: "single name (no org)",
input: "huggingface.co/model:latest",
expectedRepo: "model",
expectedRev: "main",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, rev := parseHFReference(tt.input)
if repo != tt.expectedRepo {
t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo)
}
if rev != tt.expectedRev {
t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev)
}
})
}
}

func TestIsNotOCIError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{"nil error", nil, false},
{"generic error", errors.New("some error"), false},
{"manifest unknown in message", errors.New("MANIFEST_UNKNOWN: manifest not found"), true},
{"name unknown in message", errors.New("NAME_UNKNOWN: repository not found"), true},
{"manifest unknown lowercase", errors.New("manifest unknown"), true},
{"unrelated error", errors.New("network timeout"), false},
{"HuggingFace not GGUF error", errors.New("Repository is not GGUF or is not compatible with llama.cpp"), true},
{"HuggingFace llama.cpp incompatible", errors.New("not compatible with llama.cpp"), true},
// registry.Error typed error cases
{"registry error MANIFEST_UNKNOWN", &registry.Error{Code: "MANIFEST_UNKNOWN"}, true},
{"registry error NAME_UNKNOWN", &registry.Error{Code: "NAME_UNKNOWN"}, true},
{"registry error other code", &registry.Error{Code: "UNAUTHORIZED"}, false},
// ErrInvalidReference is NOT treated as "not OCI" - it's a format error
// that should be reported to the user. Model names are lowercased during
// normalization to ensure OCI compatibility.
{"invalid reference error", registry.ErrInvalidReference, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNotOCIError(tt.err)
if result != tt.expected {
t.Errorf("isNotOCIError(%v) = %v, want %v", tt.err, result, tt.expected)
}
})
}
}

// Helper function to load a test model and return its ID
func loadTestModel(t *testing.T, client *Client, ggufPath string) string {
t.Helper()
Expand Down
Loading