Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,5 +408,16 @@ func checkCompat(image types.ModelArtifact) error {
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 {
return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType)
}

// Check if the model format is supported
config, err := image.Config()
if err != nil {
return fmt.Errorf("reading model config: %w", err)
}
Comment on lines +413 to +416
Copy link

Copilot AI Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The config is retrieved and parsed even when it may not be needed. Consider checking if the config has already been retrieved earlier in the call chain to avoid redundant operations.

Copilot uses AI. Check for mistakes.

if config.Format == types.FormatSafetensors {
return ErrUnsupportedFormat
}

return nil
}
51 changes: 51 additions & 0 deletions pkg/distribution/distribution/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/docker/model-runner/pkg/distribution/internal/gguf"
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
"github.com/docker/model-runner/pkg/distribution/internal/progress"
"github.com/docker/model-runner/pkg/distribution/internal/safetensors"
mdregistry "github.com/docker/model-runner/pkg/distribution/registry"
)

Expand Down Expand Up @@ -417,6 +418,56 @@ func TestClientPullModel(t *testing.T) {
}
})

t.Run("pull safetensors model returns error", func(t *testing.T) {
// Create temp directory for the safetensors file
tempDir, err := os.MkdirTemp("", "safetensors-test-*")
if err != nil {
t.Fatalf("Failed to create temp directory: %v", err)
}
defer os.RemoveAll(tempDir)

// Create a minimal safetensors file (just needs to exist for this test)
safetensorsPath := filepath.Join(tempDir, "model.safetensors")
safetensorsContent := []byte("fake safetensors content for testing")
if err := os.WriteFile(safetensorsPath, safetensorsContent, 0644); err != nil {
t.Fatalf("Failed to create safetensors file: %v", err)
}

// Create a safetensors model
safetensorsModel, err := safetensors.NewModel([]string{safetensorsPath})
if err != nil {
t.Fatalf("Failed to create safetensors model: %v", err)
}

// Push to registry
tag := registry + "/safetensors-test/model:v1.0.0"
ref, err := name.ParseReference(tag)
if err != nil {
t.Fatalf("Failed to parse reference: %v", err)
}
if err := remote.Write(ref, safetensorsModel); err != nil {
t.Fatalf("Failed to push safetensors model to registry: %v", err)
}

// Create a new client with a separate temp store
clientTempDir, err := os.MkdirTemp("", "client-safetensors-test-*")
if err != nil {
t.Fatalf("Failed to create client temp directory: %v", err)
}
defer os.RemoveAll(clientTempDir)

testClient, err := NewClient(WithStoreRootPath(clientTempDir))
if err != nil {
t.Fatalf("Failed to create test client: %v", err)
}

// Try to pull the safetensors model - should fail with ErrUnsupportedFormat
err = testClient.PullModel(context.Background(), tag, nil)
if !errors.Is(err, ErrUnsupportedFormat) {
t.Fatalf("Expected ErrUnsupportedFormat, got: %v", err)
}
})

t.Run("pull with JSON progress messages", func(t *testing.T) {
// Create temp directory for store
tempDir, err := os.MkdirTemp("", "model-distribution-json-test-*")
Expand Down
3 changes: 2 additions & 1 deletion pkg/distribution/distribution/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ var (
"client supports only models of type %q and older - try upgrading",
types.MediaTypeModelConfigV01,
))
ErrConflict = errors.New("resource conflict")
ErrUnsupportedFormat = errors.New("safetensors models are not currently supported - this runner only supports GGUF format models")
ErrConflict = errors.New("resource conflict")
)

// ReferenceError represents an error related to an invalid model reference
Expand Down
Loading