Skip to content

Commit 2e7c686

Browse files
committed
Add ability to pull vllm-compatible hf models
This commit introduces native HuggingFace model support by adding a new HuggingFace client implementation that can download safetensors files directly from HuggingFace Hub repositories. The changes include: A new HuggingFace client with authentication, file listing, and download capabilities. The client handles LFS files, error responses, and rate limiting appropriately. A downloader component that manages parallel file downloads with progress reporting and temporary file storage. It includes progress tracking and concurrent download limiting. Model building functionality that downloads files from HuggingFace repositories and constructs OCI model artifacts using the existing builder framework. Repository utilities for file classification, filtering, and size calculations to identify safetensors and config files needed for model construction. Integration with the existing pull mechanism to detect HuggingFace references and attempt native pulling when no OCI manifest is found. This preserves existing OCI functionality while adding fallback support for raw HuggingFace repositories. Signed-off-by: Eric Curtin <[email protected]>
1 parent fb80c6d commit 2e7c686

File tree

8 files changed

+1188
-1
lines changed

8 files changed

+1188
-1
lines changed

pkg/distribution/distribution/client.go

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"os"
910
"slices"
1011
"strings"
1112

13+
"github.com/docker/model-runner/pkg/distribution/huggingface"
1214
"github.com/docker/model-runner/pkg/distribution/internal/progress"
1315
"github.com/docker/model-runner/pkg/distribution/internal/store"
1416
"github.com/docker/model-runner/pkg/distribution/registry"
@@ -267,15 +269,22 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
267269

268270
// Use the client's registry, or create a temporary one if bearer token is provided
269271
registryClient := c.registry
272+
var token string
270273
if len(bearerToken) > 0 && bearerToken[0] != "" {
274+
token = bearerToken[0]
271275
// Create a temporary registry client with bearer token authentication
272-
auth := &authn.Bearer{Token: bearerToken[0]}
276+
auth := &authn.Bearer{Token: token}
273277
registryClient = registry.FromClient(c.registry, registry.WithAuth(auth))
274278
}
275279

276280
// First, fetch the remote model to get the manifest
277281
remoteModel, err := registryClient.Model(ctx, reference)
278282
if err != nil {
283+
// Check if this is a HuggingFace reference and the error indicates no OCI manifest
284+
if isHuggingFaceReference(reference) && isNotOCIError(err) {
285+
c.log.Infoln("No OCI manifest found, attempting native HuggingFace pull")
286+
return c.pullNativeHuggingFace(ctx, reference, progressWriter, token)
287+
}
279288
return fmt.Errorf("reading model from registry: %w", err)
280289
}
281290

@@ -637,3 +646,103 @@ func checkCompat(image types.ModelArtifact, log *logrus.Entry, reference string,
637646

638647
return nil
639648
}
649+
650+
// isHuggingFaceReference checks if a reference is a HuggingFace model reference
651+
func isHuggingFaceReference(reference string) bool {
652+
return strings.HasPrefix(reference, "huggingface.co/")
653+
}
654+
655+
// isNotOCIError checks if the error indicates the model is not OCI-formatted
656+
// This happens when the HuggingFace repository doesn't have an OCI manifest
657+
func isNotOCIError(err error) bool {
658+
if err == nil {
659+
return false
660+
}
661+
662+
// Check for registry errors indicating no manifest
663+
var regErr *registry.Error
664+
if errors.As(err, &regErr) {
665+
return regErr.Code == "MANIFEST_UNKNOWN" || regErr.Code == "NAME_UNKNOWN"
666+
}
667+
668+
// Also check error message for common patterns
669+
errStr := err.Error()
670+
return strings.Contains(errStr, "MANIFEST_UNKNOWN") ||
671+
strings.Contains(errStr, "NAME_UNKNOWN") ||
672+
strings.Contains(errStr, "manifest unknown")
673+
}
674+
675+
// parseHFReference extracts repo and revision from a normalized HF reference
676+
// e.g., "huggingface.co/org/model:revision" -> ("org/model", "revision")
677+
// e.g., "huggingface.co/org/model:latest" -> ("org/model", "main")
678+
func parseHFReference(reference string) (repo, revision string) {
679+
// Remove registry prefix
680+
ref := strings.TrimPrefix(reference, "huggingface.co/")
681+
682+
// Split by colon to get tag
683+
parts := strings.SplitN(ref, ":", 2)
684+
repo = parts[0]
685+
686+
revision = "main"
687+
if len(parts) == 2 && parts[1] != "" && parts[1] != "latest" {
688+
revision = parts[1]
689+
}
690+
691+
return repo, revision
692+
}
693+
694+
// pullNativeHuggingFace pulls a native HuggingFace repository (non-OCI format)
695+
// This is used when the model is stored as raw files (safetensors) on HuggingFace Hub
696+
func (c *Client) pullNativeHuggingFace(ctx context.Context, reference string, progressWriter io.Writer, token string) error {
697+
repo, revision := parseHFReference(reference)
698+
c.log.Infof("Pulling native HuggingFace model: repo=%s, revision=%s", repo, revision)
699+
700+
// Create HuggingFace client
701+
hfOpts := []huggingface.ClientOption{
702+
huggingface.WithUserAgent(registry.DefaultUserAgent),
703+
}
704+
if token != "" {
705+
hfOpts = append(hfOpts, huggingface.WithToken(token))
706+
}
707+
hfClient := huggingface.NewClient(hfOpts...)
708+
709+
// Create temp directory for downloads
710+
tempDir, err := os.MkdirTemp("", "hf-model-*")
711+
if err != nil {
712+
return fmt.Errorf("create temp dir: %w", err)
713+
}
714+
defer os.RemoveAll(tempDir)
715+
716+
// Build model from HuggingFace repository
717+
model, err := huggingface.BuildModel(ctx, hfClient, repo, revision, tempDir, progressWriter)
718+
if err != nil {
719+
// Convert HuggingFace errors to registry errors for consistent handling
720+
var authErr *huggingface.AuthError
721+
var notFoundErr *huggingface.NotFoundError
722+
if errors.As(err, &authErr) {
723+
return registry.ErrUnauthorized
724+
}
725+
if errors.As(err, &notFoundErr) {
726+
return registry.ErrModelNotFound
727+
}
728+
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
729+
c.log.Warnf("Failed to write error message: %v", writeErr)
730+
}
731+
return fmt.Errorf("build model from HuggingFace: %w", err)
732+
}
733+
734+
// Write model to store
735+
c.log.Infoln("Writing model to store")
736+
if err := c.store.Write(model, []string{reference}, progressWriter); err != nil {
737+
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
738+
c.log.Warnf("Failed to write error message: %v", writeErr)
739+
}
740+
return fmt.Errorf("writing model to store: %w", err)
741+
}
742+
743+
if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil {
744+
c.log.Warnf("Failed to write success message: %v", err)
745+
}
746+
747+
return nil
748+
}

pkg/distribution/distribution/normalize_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package distribution
22

33
import (
44
"context"
5+
"errors"
56
"io"
67
"path/filepath"
78
"strings"
@@ -355,6 +356,106 @@ func createTestClient(t *testing.T) (*Client, func()) {
355356
return client, cleanup
356357
}
357358

359+
func TestIsHuggingFaceReference(t *testing.T) {
360+
tests := []struct {
361+
name string
362+
input string
363+
expected bool
364+
}{
365+
{"huggingface.co prefix", "huggingface.co/org/model:latest", true},
366+
{"huggingface.co without tag", "huggingface.co/org/model", true},
367+
{"not huggingface", "registry.example.com/model:latest", false},
368+
{"docker hub", "ai/gemma3:latest", false},
369+
{"hf.co prefix (not normalized)", "hf.co/org/model", false}, // This is the un-normalized form
370+
{"empty", "", false},
371+
}
372+
373+
for _, tt := range tests {
374+
t.Run(tt.name, func(t *testing.T) {
375+
result := isHuggingFaceReference(tt.input)
376+
if result != tt.expected {
377+
t.Errorf("isHuggingFaceReference(%q) = %v, want %v", tt.input, result, tt.expected)
378+
}
379+
})
380+
}
381+
}
382+
383+
func TestParseHFReference(t *testing.T) {
384+
tests := []struct {
385+
name string
386+
input string
387+
expectedRepo string
388+
expectedRev string
389+
}{
390+
{
391+
name: "basic with latest tag",
392+
input: "huggingface.co/org/model:latest",
393+
expectedRepo: "org/model",
394+
expectedRev: "main", // latest maps to main
395+
},
396+
{
397+
name: "with explicit revision",
398+
input: "huggingface.co/org/model:v1.0",
399+
expectedRepo: "org/model",
400+
expectedRev: "v1.0",
401+
},
402+
{
403+
name: "without tag",
404+
input: "huggingface.co/org/model",
405+
expectedRepo: "org/model",
406+
expectedRev: "main",
407+
},
408+
{
409+
name: "with commit hash as tag",
410+
input: "huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct:abc123",
411+
expectedRepo: "HuggingFaceTB/SmolLM2-135M-Instruct",
412+
expectedRev: "abc123",
413+
},
414+
{
415+
name: "single name (no org)",
416+
input: "huggingface.co/model:latest",
417+
expectedRepo: "model",
418+
expectedRev: "main",
419+
},
420+
}
421+
422+
for _, tt := range tests {
423+
t.Run(tt.name, func(t *testing.T) {
424+
repo, rev := parseHFReference(tt.input)
425+
if repo != tt.expectedRepo {
426+
t.Errorf("parseHFReference(%q) repo = %q, want %q", tt.input, repo, tt.expectedRepo)
427+
}
428+
if rev != tt.expectedRev {
429+
t.Errorf("parseHFReference(%q) rev = %q, want %q", tt.input, rev, tt.expectedRev)
430+
}
431+
})
432+
}
433+
}
434+
435+
func TestIsNotOCIError(t *testing.T) {
436+
tests := []struct {
437+
name string
438+
err error
439+
expected bool
440+
}{
441+
{"nil error", nil, false},
442+
{"generic error", errors.New("some error"), false},
443+
{"manifest unknown in message", errors.New("MANIFEST_UNKNOWN: manifest not found"), true},
444+
{"name unknown in message", errors.New("NAME_UNKNOWN: repository not found"), true},
445+
{"manifest unknown lowercase", errors.New("manifest unknown"), true},
446+
{"unrelated error", errors.New("network timeout"), false},
447+
}
448+
449+
for _, tt := range tests {
450+
t.Run(tt.name, func(t *testing.T) {
451+
result := isNotOCIError(tt.err)
452+
if result != tt.expected {
453+
t.Errorf("isNotOCIError(%v) = %v, want %v", tt.err, result, tt.expected)
454+
}
455+
})
456+
}
457+
}
458+
358459
// Helper function to load a test model and return its ID
359460
func loadTestModel(t *testing.T, client *Client, ggufPath string) string {
360461
t.Helper()

0 commit comments

Comments
 (0)