Skip to content

Commit 50465f9

Browse files
feat: add support for Safetensors model format error handling (#191)
* feat: add support for safetensors model format error handling * Update pkg/distribution/distribution/client_test.go Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent de7e2d2 commit 50465f9

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

pkg/distribution/distribution/client.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,5 +408,16 @@ func checkCompat(image types.ModelArtifact) error {
408408
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 {
409409
return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType)
410410
}
411+
412+
// Check if the model format is supported
413+
config, err := image.Config()
414+
if err != nil {
415+
return fmt.Errorf("reading model config: %w", err)
416+
}
417+
418+
if config.Format == types.FormatSafetensors {
419+
return ErrUnsupportedFormat
420+
}
421+
411422
return nil
412423
}

pkg/distribution/distribution/client_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/docker/model-runner/pkg/distribution/internal/gguf"
2525
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
2626
"github.com/docker/model-runner/pkg/distribution/internal/progress"
27+
"github.com/docker/model-runner/pkg/distribution/internal/safetensors"
2728
mdregistry "github.com/docker/model-runner/pkg/distribution/registry"
2829
)
2930

@@ -417,6 +418,56 @@ func TestClientPullModel(t *testing.T) {
417418
}
418419
})
419420

421+
t.Run("pull safetensors model returns error", func(t *testing.T) {
422+
// Create temp directory for the safetensors file
423+
tempDir, err := os.MkdirTemp("", "safetensors-test-*")
424+
if err != nil {
425+
t.Fatalf("Failed to create temp directory: %v", err)
426+
}
427+
defer os.RemoveAll(tempDir)
428+
429+
// Create a minimal safetensors file (just needs to exist for this test)
430+
safetensorsPath := filepath.Join(tempDir, "model.safetensors")
431+
safetensorsContent := []byte("fake safetensors content for testing")
432+
if err := os.WriteFile(safetensorsPath, safetensorsContent, 0644); err != nil {
433+
t.Fatalf("Failed to create safetensors file: %v", err)
434+
}
435+
436+
// Create a safetensors model
437+
safetensorsModel, err := safetensors.NewModel([]string{safetensorsPath})
438+
if err != nil {
439+
t.Fatalf("Failed to create safetensors model: %v", err)
440+
}
441+
442+
// Push to registry
443+
tag := registry + "/safetensors-test/model:v1.0.0"
444+
ref, err := name.ParseReference(tag)
445+
if err != nil {
446+
t.Fatalf("Failed to parse reference: %v", err)
447+
}
448+
if err := remote.Write(ref, safetensorsModel); err != nil {
449+
t.Fatalf("Failed to push safetensors model to registry: %v", err)
450+
}
451+
452+
// Create a new client with a separate temp store
453+
clientTempDir, err := os.MkdirTemp("", "client-safetensors-test-*")
454+
if err != nil {
455+
t.Fatalf("Failed to create client temp directory: %v", err)
456+
}
457+
defer os.RemoveAll(clientTempDir)
458+
459+
testClient, err := NewClient(WithStoreRootPath(clientTempDir))
460+
if err != nil {
461+
t.Fatalf("Failed to create test client: %v", err)
462+
}
463+
464+
// Try to pull the safetensors model - should fail with ErrUnsupportedFormat
465+
err = testClient.PullModel(context.Background(), tag, nil)
466+
if !errors.Is(err, ErrUnsupportedFormat) {
467+
t.Fatalf("Expected ErrUnsupportedFormat, got: %v", err)
468+
}
469+
})
470+
420471
t.Run("pull with JSON progress messages", func(t *testing.T) {
421472
// Create temp directory for store
422473
tempDir, err := os.MkdirTemp("", "model-distribution-json-test-*")

pkg/distribution/distribution/errors.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ var (
1616
"client supports only models of type %q and older - try upgrading",
1717
types.MediaTypeModelConfigV01,
1818
))
19-
ErrConflict = errors.New("resource conflict")
19+
ErrUnsupportedFormat = errors.New("safetensors models are not currently supported - this runner only supports GGUF format models")
20+
ErrConflict = errors.New("resource conflict")
2021
)
2122

2223
// ReferenceError represents an error related to an invalid model reference

0 commit comments

Comments
 (0)