diff --git a/cmd/cli/commands/integration_test.go b/cmd/cli/commands/integration_test.go index 3e95586e5..935709435 100644 --- a/cmd/cli/commands/integration_test.go +++ b/cmd/cli/commands/integration_test.go @@ -587,32 +587,6 @@ func TestIntegration_TagModel(t *testing.T) { }) } - // Final verification: List the model and verify all tags are present - t.Run("verify all tags in model inspect", func(t *testing.T) { - inspectedModel, err := env.client.Inspect(modelID, false) - require.NoError(t, err, "Failed to inspect model by ID") - - t.Logf("Model has %d tags: %v", len(inspectedModel.Tags), inspectedModel.Tags) - - // The model should have at least the original tag plus all created tags - require.GreaterOrEqual(t, len(inspectedModel.Tags), len(createdTags)+1, - "Model should have at least %d tags (original + created)", len(createdTags)+1) - - // Verify each created tag is in the model's tag list - for expectedTag := range createdTags { - found := false - for _, actualTag := range inspectedModel.Tags { - if actualTag == expectedTag || actualTag == fmt.Sprintf("%s:latest", expectedTag) { // Handle implicit latest tag - found = true - break - } - } - require.True(t, found, "Expected tag %s not found in model's tag list", expectedTag) - } - - t.Logf("✓ All %d created tags verified in model's tag list", len(createdTags)) - }) - // Test error case: tagging non-existent model t.Run("error on non-existent model", func(t *testing.T) { err := tagModel(newTagCmd(), env.client, "non-existent-model:v1", "ai/should-fail:latest") @@ -1016,6 +990,131 @@ func TestIntegration_RemoveModel(t *testing.T) { }) } +// TestIntegration_PackageModel tests packaging a GGUF model file +// to ensure the model is properly loaded and tagged in the model store. +// This test reproduces issue #461 where packaging fails with "model not found" during tagging. +func TestIntegration_PackageModel(t *testing.T) { + env := setupTestEnv(t) + + // Ensure no models exist initially + models, err := listModels(false, env.client, true, false, "") + require.NoError(t, err) + if len(models) != 0 { + t.Fatal("Expected no initial models, but found some") + } + + // Use the dummy GGUF file from assets + dummyGGUFPath := filepath.Join("../../../assets/dummy.gguf") + absPath, err := filepath.Abs(dummyGGUFPath) + require.NoError(t, err) + + // Check if the file exists + _, err = os.Stat(absPath) + require.NoError(t, err, "dummy.gguf not found at %s", absPath) + + // Test case 1: Package a GGUF file with a simple tag + t.Run("package GGUF with simple tag", func(t *testing.T) { + targetTag := "ai/packaged-test:latest" + + // Create package options + opts := packageOptions{ + ggufPath: absPath, + tag: targetTag, + } + + // Execute the package command using the helper function with test client + t.Logf("Packaging GGUF file %s as %s", absPath, targetTag) + err := packageModel(env.ctx, newPackagedCmd(), env.client, opts) + require.NoError(t, err, "Failed to package GGUF model") + + // Verify the model was loaded and tagged + t.Logf("Verifying model was loaded and tagged") + models, err := listModels(false, env.client, false, false, "") + require.NoError(t, err) + require.NotEmpty(t, models, "No models found after packaging") + + // Verify we can inspect the model by tag + model, err := env.client.Inspect(targetTag, false) + require.NoError(t, err, "Failed to inspect packaged model by tag: %s", targetTag) + require.NotEmpty(t, model.ID, "Model ID should not be empty") + require.Contains(t, model.Tags, targetTag, "Model should have the expected tag") + + t.Logf("✓ Successfully packaged and tagged model: %s (ID: %s)", targetTag, model.ID[7:19]) + + // Cleanup + err = removeModel(env.client, model.ID, true) + require.NoError(t, err, "Failed to remove model") + }) + + // Test case 2: Package with context size override + t.Run("package GGUF with context size", func(t *testing.T) { + targetTag := "ai/packaged-ctx:latest" + + // Create package options with context size + opts := packageOptions{ + ggufPath: absPath, + tag: targetTag, + contextSize: 4096, + } + + // Create a command for context + cmd := newPackagedCmd() + // Set the flag as changed for context size + cmd.Flags().Set("context-size", "4096") + + // Execute the package command using the helper function with test client + t.Logf("Packaging GGUF file with context size 4096 as %s", targetTag) + err := packageModel(env.ctx, cmd, env.client, opts) + require.NoError(t, err, "Failed to package GGUF model with context size") + + // Verify the model was loaded and tagged + model, err := env.client.Inspect(targetTag, false) + require.NoError(t, err, "Failed to inspect packaged model") + require.Contains(t, model.Tags, targetTag, "Model should have the expected tag") + + t.Logf("✓ Successfully packaged model with context size: %s", targetTag) + + // Cleanup + err = removeModel(env.client, model.ID, true) + require.NoError(t, err, "Failed to remove model") + }) + + // Test case 3: Package with different org + t.Run("package GGUF with custom org", func(t *testing.T) { + targetTag := "myorg/packaged-test:v1" + + // Create package options + opts := packageOptions{ + ggufPath: absPath, + tag: targetTag, + } + + // Create a command for context + cmd := newPackagedCmd() + + // Execute the package command using the helper function with test client + t.Logf("Packaging GGUF file as %s", targetTag) + err := packageModel(env.ctx, cmd, env.client, opts) + require.NoError(t, err, "Failed to package GGUF model with custom org") + + // Verify the model was loaded and tagged + model, err := env.client.Inspect(targetTag, false) + require.NoError(t, err, "Failed to inspect packaged model") + require.Contains(t, model.Tags, targetTag, "Model should have the expected tag") + + t.Logf("✓ Successfully packaged model with custom org: %s", targetTag) + + // Cleanup + err = removeModel(env.client, model.ID, true) + require.NoError(t, err, "Failed to remove model") + }) + + // Verify all models are cleaned up + models, err = listModels(false, env.client, true, false, "") + require.NoError(t, err) + require.Empty(t, strings.TrimSpace(models), "All models should be removed after cleanup") +} + func int32ptr(n int32) *int32 { return &n } diff --git a/cmd/cli/commands/list.go b/cmd/cli/commands/list.go index 7dc3da379..171049d0a 100644 --- a/cmd/cli/commands/list.go +++ b/cmd/cli/commands/list.go @@ -72,27 +72,39 @@ func listModels(openai bool, desktopClient *desktop.Client, quiet bool, jsonForm } if modelFilter != "" { - // Normalize the filter to match stored model names (backend normalizes when storing) - normalizedFilter := dmrm.NormalizeModelName(modelFilter) + // If filter doesn't contain '/', prepend default namespace 'ai/' + if !strings.Contains(modelFilter, "/") { + modelFilter = "ai/" + modelFilter + } + var filteredModels []dmrm.Model + + // Check if filter has a colon (i.e., includes a tag) + hasColon := strings.Contains(modelFilter, ":") + for _, m := range models { - hasMatchingTag := false + var matchingTags []string for _, tag := range m.Tags { - // Tags are stored in normalized format by the backend - if tag == normalizedFilter { - hasMatchingTag = true - break - } - // Also check without the tag part - modelName, _, _ := strings.Cut(tag, ":") - filterName, _, _ := strings.Cut(normalizedFilter, ":") - if modelName == filterName { - hasMatchingTag = true - break + if hasColon { + // Filter includes a tag part - do exact match + // Tags are stored in normalized format by the backend + if tag == modelFilter { + matchingTags = append(matchingTags, tag) + } + } else { + // Filter has no colon - match repository name only (part before ':') + repository, _, _ := strings.Cut(tag, ":") + if repository == modelFilter { + matchingTags = append(matchingTags, tag) + } } } - if hasMatchingTag { - filteredModels = append(filteredModels, m) + // Only include the model if at least one tag matched, and only include matching tags + if len(matchingTags) > 0 { + // Create a copy of the model with only the matching tags + filteredModel := m + filteredModel.Tags = matchingTags + filteredModels = append(filteredModels, filteredModel) } } models = filteredModels diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index 788d6c8cc..1d6a1f2f1 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -131,7 +131,7 @@ func newPackagedCmd() *cobra.Command { }, RunE: func(cmd *cobra.Command, args []string) error { opts.tag = args[0] - if err := packageModel(cmd, opts); err != nil { + if err := packageModel(cmd.Context(), cmd, desktopClient, opts); err != nil { cmd.PrintErrln("Failed to package model") return fmt.Errorf("package model: %w", err) } @@ -254,7 +254,7 @@ func initializeBuilder(cmd *cobra.Command, opts packageOptions) (*builderInitRes return result, nil } -func packageModel(cmd *cobra.Command, opts packageOptions) error { +func packageModel(ctx context.Context, cmd *cobra.Command, client *desktop.Client, opts packageOptions) error { var ( target builder.Target err error @@ -264,7 +264,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error { registry.WithUserAgent("docker-model-cli/" + desktop.Version), ).NewTarget(opts.tag) } else { - target, err = newModelRunnerTarget(desktopClient, opts.tag) + target, err = newModelRunnerTarget(client, opts.tag) } if err != nil { return err @@ -357,7 +357,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error { done := make(chan error, 1) go func() { defer pw.Close() - done <- pkg.Build(cmd.Context(), target, pw) + done <- pkg.Build(ctx, target, pw) }() scanner := bufio.NewScanner(pr) @@ -443,7 +443,7 @@ func (t *modelRunnerTarget) Write(ctx context.Context, mdl types.ModelArtifact, return fmt.Errorf("get model ID: %w", err) } if t.tag.String() != "" { - if err := desktopClient.Tag(id, parseRepo(t.tag), t.tag.TagStr()); err != nil { + if err := t.client.Tag(id, parseRepo(t.tag), t.tag.TagStr()); err != nil { return fmt.Errorf("tag model: %w", err) } } diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go index 9d3b22744..0433fc06f 100644 --- a/cmd/cli/commands/utils_test.go +++ b/cmd/cli/commands/utils_test.go @@ -4,88 +4,8 @@ import ( "errors" "fmt" "testing" - - "github.com/docker/model-runner/pkg/inference/models" ) -func TestNormalizeModelName(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - { - name: "simple model name", - input: "gemma3", - expected: "ai/gemma3:latest", - }, - { - name: "model name with tag", - input: "gemma3:v1", - expected: "ai/gemma3:v1", - }, - { - name: "model name with org", - input: "myorg/gemma3", - expected: "myorg/gemma3:latest", - }, - { - name: "model name with org and tag", - input: "myorg/gemma3:v1", - expected: "myorg/gemma3:v1", - }, - { - name: "fully qualified model name", - input: "ai/gemma3:latest", - expected: "ai/gemma3:latest", - }, - { - name: "huggingface model", - input: "hf.co/bartowski/model", - expected: "huggingface.co/bartowski/model:latest", - }, - { - name: "huggingface model with tag", - input: "hf.co/bartowski/model:Q4_K_S", - expected: "huggingface.co/bartowski/model:q4_k_s", - }, - { - name: "registry with model", - input: "docker.io/library/model", - expected: "docker.io/library/model:latest", - }, - { - name: "registry with model and tag", - input: "docker.io/library/model:v1", - expected: "docker.io/library/model:v1", - }, - { - name: "empty string", - input: "", - expected: "", - }, - { - name: "ai prefix already present", - input: "ai/gemma3", - expected: "ai/gemma3:latest", - }, - { - name: "model name with latest tag already", - input: "gemma3:latest", - expected: "ai/gemma3:latest", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := models.NormalizeModelName(tt.input) - if result != tt.expected { - t.Errorf("NormalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) - } - }) - } -} - func TestStripDefaultsFromModelName(t *testing.T) { tests := []struct { name string @@ -174,54 +94,6 @@ func TestStripDefaultsFromModelName(t *testing.T) { } } -// TestNormalizeModelNameConsistency verifies that locally packaged models -// (without namespace) get normalized the same way as other operations. -// This test documents the fix for the bug where `docker model package my-model` -// would create a model that couldn't be run with `docker model run my-model`. -func TestNormalizeModelNameConsistency(t *testing.T) { - tests := []struct { - name string - userProvidedName string - expectedNormalizedName string - description string - }{ - { - name: "locally packaged model without namespace", - userProvidedName: "my-model", - expectedNormalizedName: "ai/my-model:latest", - description: "When a user packages a local model as 'my-model', it should be normalized to 'ai/my-model:latest'", - }, - { - name: "locally packaged model without namespace but with tag", - userProvidedName: "my-model:v1.0", - expectedNormalizedName: "ai/my-model:v1.0", - description: "When a user packages a local model as 'my-model:v1.0', it should be normalized to 'ai/my-model:v1.0'", - }, - { - name: "model with explicit namespace", - userProvidedName: "myorg/my-model", - expectedNormalizedName: "myorg/my-model:latest", - description: "When a user packages a model with explicit org 'myorg/my-model', it should keep the org", - }, - { - name: "model with ai namespace explicitly set", - userProvidedName: "ai/my-model", - expectedNormalizedName: "ai/my-model:latest", - description: "When a user explicitly sets 'ai/' namespace, it should remain the same", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := models.NormalizeModelName(tt.userProvidedName) - if result != tt.expectedNormalizedName { - t.Errorf("%s: NormalizeModelName(%q) = %q, want %q", - tt.description, tt.userProvidedName, result, tt.expectedNormalizedName) - } - }) - } -} - // TestHandleClientErrorFormat verifies that the error format follows the expected pattern. func TestHandleClientErrorFormat(t *testing.T) { t.Run("error format is message: original error", func(t *testing.T) { diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index f47150dd3..05e1d3307 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -56,14 +56,6 @@ type Status struct { Error error `json:"error"` } -// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase -func normalizeHuggingFaceModelName(model string) string { - if strings.HasPrefix(model, "hf.co/") { - return strings.ToLower(model) - } - return model -} - func (c *Client) Status() Status { // TODO: Query "/". resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil) @@ -106,8 +98,6 @@ func (c *Client) Status() Status { } func (c *Client) Pull(model string, printer standalone.StatusPrinter) (string, bool, error) { - model = normalizeHuggingFaceModelName(model) - // Check if this is a Hugging Face model and if HF_TOKEN is set var hfToken string if strings.HasPrefix(strings.ToLower(model), "hf.co/") { @@ -233,8 +223,6 @@ func (c *Client) withRetries( } func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) { - model = normalizeHuggingFaceModelName(model) - return c.withRetries("push", 3, printer, func(attempt int) (string, bool, error, bool) { pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( @@ -303,22 +291,6 @@ func (c *Client) ListOpenAI() (dmrm.OpenAIModelList, error) { } func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { - model = normalizeHuggingFaceModelName(model) - if model != "" { - // Only try to expand to model ID if the reference doesn't contain: - // - A slash (org/name format) - // - A colon (tagged reference like name:tag) - // - An @ symbol (digest reference like name@sha256:...) - if !strings.Contains(strings.Trim(model, "/"), "/") && - !strings.Contains(model, ":") && - !strings.Contains(model, "@") { - // Do an extra API call to check if the model parameter isn't a model ID. - modelId, err := c.fullModelID(model) - if err == nil { - model = modelId - } - } - } rawResponse, err := c.listRawWithQuery(fmt.Sprintf("%s/%s", inference.ModelsPrefix, model), model, remote) if err != nil { return dmrm.Model{}, err @@ -332,15 +304,7 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) { } func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) { - model = normalizeHuggingFaceModelName(model) modelsRoute := inference.InferencePrefix + "/v1/models" - if !strings.Contains(strings.Trim(model, "/"), "/") { - // Do an extra API call to check if the model parameter isn't a model ID. - var err error - if model, err = c.fullModelID(model); err != nil { - return dmrm.OpenAIModel{}, fmt.Errorf("invalid model name: %s", model) - } - } rawResponse, err := c.listRaw(fmt.Sprintf("%s/%s", modelsRoute, model), model) if err != nil { return dmrm.OpenAIModel{}, err @@ -381,39 +345,6 @@ func (c *Client) listRawWithQuery(route string, model string, remote bool) ([]by return body, nil } -func (c *Client) fullModelID(id string) (string, error) { - bodyResponse, err := c.listRaw(inference.ModelsPrefix, "") - if err != nil { - return "", err - } - - var modelsJson []dmrm.Model - if err := json.Unmarshal(bodyResponse, &modelsJson); err != nil { - return "", fmt.Errorf("failed to unmarshal response body: %w", err) - } - - for _, m := range modelsJson { - if m.ID[7:19] == id || strings.TrimPrefix(m.ID, "sha256:") == id || m.ID == id { - return m.ID, nil - } - // Check if the ID matches any of the model's tags using exact match first - for _, tag := range m.Tags { - if tag == id { - return m.ID, nil - } - } - - // Normalize everything and try to find exact matches - for _, tag := range m.Tags { - if dmrm.NormalizeModelName(tag) == dmrm.NormalizeModelName(id) { - return m.ID, nil - } - } - } - - return "", fmt.Errorf("model with ID %s not found", id) -} - // Chat performs a chat request and streams the response content with selective markdown rendering. func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) error { return c.ChatWithContext(context.Background(), model, prompt, imageURLs, outputFunc, shouldUseMarkdown) @@ -421,14 +352,6 @@ func (c *Client) Chat(model, prompt string, imageURLs []string, outputFunc func( // ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering. func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imageURLs []string, outputFunc func(string), shouldUseMarkdown bool) error { - model = normalizeHuggingFaceModelName(model) - if !strings.Contains(strings.Trim(model, "/"), "/") { - // Do an extra API call to check if the model parameter isn't a model ID. - if expanded, err := c.fullModelID(model); err == nil { - model = expanded - } - } - // Build the message content - either simple string or multimodal array var messageContent interface{} if len(imageURLs) > 0 { @@ -597,32 +520,6 @@ func (c *Client) ChatWithContext(ctx context.Context, model, prompt string, imag func (c *Client) Remove(modelArgs []string, force bool) (string, error) { modelRemoved := "" for _, model := range modelArgs { - model = normalizeHuggingFaceModelName(model) - - // Handle digest references (model@sha256:...) - // These need to be normalized to include default org if missing - if strings.Contains(model, "@") && !strings.Contains(model, "/") { - // Split on @ to get repository and digest - parts := strings.SplitN(model, "@", 2) - if len(parts) == 2 { - repo := parts[0] - digest := parts[1] - // Add default org if the repository doesn't contain a slash - if !strings.Contains(repo, "/") { - model = fmt.Sprintf("ai/%s@%s", repo, digest) - } - } - } - - // Only expand simple names without tags or digests to model IDs - // Tagged references (model:tag) and digest references (model@sha256:...) - // should be passed as-is to allow tag-specific operations - if !strings.Contains(model, "/") && !strings.Contains(model, ":") && !strings.Contains(model, "@") { - if expanded, err := c.fullModelID(model); err == nil { - model = expanded - } - } - // Construct the URL with query parameters removePath := fmt.Sprintf("%s/%s?force=%s", inference.ModelsPrefix, @@ -912,11 +809,6 @@ func (c *Client) handleQueryError(err error, path string) error { } func (c *Client) Tag(source, targetRepo, targetTag string) error { - source = normalizeHuggingFaceModelName(source) - // For tag operations, let the daemon handle name resolution to support - // partial name matching like "smollm2" -> "ai/smollm2:latest" - // Don't do client-side ID expansion which can cause issues with tagging - // Construct the URL with query parameters using the normalized source tagPath := fmt.Sprintf("%s/%s/tag?repo=%s&tag=%s", inference.ModelsPrefix, diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go index d09eee4bd..2719c64c8 100644 --- a/cmd/cli/desktop/desktop_test.go +++ b/cmd/cli/desktop/desktop_test.go @@ -3,233 +3,16 @@ package desktop import ( "bytes" "context" - "encoding/json" "errors" "io" "net/http" "testing" mockdesktop "github.com/docker/model-runner/cmd/cli/mocks" - "github.com/docker/model-runner/pkg/inference/models" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) -func TestPullHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for pulling a Hugging Face model with mixed case - modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - var reqBody models.ModelCreateRequest - err := json.NewDecoder(req.Body).Decode(&reqBody) - require.NoError(t, err) - assert.Equal(t, expectedLowercase, reqBody.From) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), - }, nil) - - printer := NewSimplePrinter(func(s string) {}) - _, _, err := client.Pull(modelName, printer) - assert.NoError(t, err) -} - -func TestChatHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for chatting with a Hugging Face model with mixed case - modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - prompt := "Hello" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - var reqBody OpenAIChatRequest - err := json.NewDecoder(req.Body).Decode(&reqBody) - require.NoError(t, err) - assert.Equal(t, expectedLowercase, reqBody.Model) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")), - }, nil) - - err := client.Chat(modelName, prompt, []string{}, func(s string) {}, false) - assert.NoError(t, err) -} - -func TestInspectHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for inspecting a Hugging Face model with mixed case - modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - assert.Contains(t, req.URL.Path, expectedLowercase) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{ - "id": "sha256:123456789012", - "tags": ["` + expectedLowercase + `"], - "created": 1234567890, - "config": { - "format": "gguf", - "quantization": "Q4_K_M", - "parameters": "1B", - "architecture": "llama", - "size": "1.2GB" - } - }`)), - }, nil) - - model, err := client.Inspect(modelName, false) - assert.NoError(t, err) - assert.Equal(t, expectedLowercase, model.Tags[0]) -} - -func TestNonHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for a non-Hugging Face model (should not be converted to lowercase) - modelName := "docker.io/library/llama2" - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - var reqBody models.ModelCreateRequest - err := json.NewDecoder(req.Body).Decode(&reqBody) - require.NoError(t, err) - assert.Equal(t, modelName, reqBody.From) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), - }, nil) - - printer := NewSimplePrinter(func(s string) {}) - _, _, err := client.Pull(modelName, printer) - assert.NoError(t, err) -} - -func TestPushHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for pushing a Hugging Face model with mixed case - modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - assert.Contains(t, req.URL.Path, expectedLowercase) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), - }, nil) - - printer := NewSimplePrinter(func(s string) {}) - _, _, err := client.Push(modelName, printer) - assert.NoError(t, err) -} - -func TestRemoveHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for removing a Hugging Face model with mixed case - modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - assert.Contains(t, req.URL.Path, expectedLowercase) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString("Model removed successfully")), - }, nil) - - _, err := client.Remove([]string{modelName}, false) - assert.NoError(t, err) -} - -func TestTagHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for tagging a Hugging Face model with mixed case - sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - targetRepo := "myrepo" - targetTag := "latest" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - assert.Contains(t, req.URL.Path, expectedLowercase) - }).Return(&http.Response{ - StatusCode: http.StatusCreated, - Body: io.NopCloser(bytes.NewBufferString("Tag created successfully")), - }, nil) - - assert.NoError(t, client.Tag(sourceModel, targetRepo, targetTag)) -} - -func TestInspectOpenAIHuggingFaceModel(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - // Test case for inspecting a Hugging Face model with mixed case - modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF" - expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf" - - mockClient := mockdesktop.NewMockDockerHttpClient(ctrl) - mockContext := NewContextForMock(mockClient) - client := New(mockContext) - - mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) { - assert.Contains(t, req.URL.Path, expectedLowercase) - }).Return(&http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(bytes.NewBufferString(`{ - "id": "` + expectedLowercase + `", - "object": "model", - "created": 1234567890, - "owned_by": "organization" - }`)), - }, nil) - - model, err := client.InspectOpenAI(modelName) - assert.NoError(t, err) - assert.Equal(t, expectedLowercase, model.ID) -} - func TestPullRetryOnNetworkError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/pkg/distribution/distribution/client.go b/pkg/distribution/distribution/client.go index 77aab22e8..c690874c4 100644 --- a/pkg/distribution/distribution/client.go +++ b/pkg/distribution/distribution/client.go @@ -138,8 +138,132 @@ func NewClient(opts ...Option) (*Client, error) { }, nil } +// normalizeModelName adds the default organization prefix (ai/) and tag (:latest) if missing. +// It also converts Hugging Face model names to lowercase and resolves IDs to full IDs. +// This is a private method used internally by the Client. +func (c *Client) normalizeModelName(model string) string { + const ( + defaultOrg = "ai" + defaultTag = "latest" + ) + + model = strings.TrimSpace(model) + + // If the model is empty, return as-is + if model == "" { + return model + } + + // If it looks like an ID or digest, try to resolve it to full ID + if c.looksLikeID(model) || c.looksLikeDigest(model) { + if fullID := c.resolveID(model); fullID != "" { + return fullID + } + // If not found, return as-is + return model + } + + // Normalize HuggingFace model names (lowercase path) + if strings.HasPrefix(model, "hf.co/") { + // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect. + model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co")) + } + + // Check if model contains a registry (domain with dot before first slash) + firstSlash := strings.Index(model, "/") + if firstSlash > 0 && strings.Contains(model[:firstSlash], ".") { + // Has a registry, just ensure tag + if !strings.Contains(model, ":") { + return model + ":" + defaultTag + } + return model + } + + // Split by colon to check for tag + parts := strings.SplitN(model, ":", 2) + nameWithOrg := parts[0] + tag := defaultTag + if len(parts) == 2 && parts[1] != "" { + tag = parts[1] + } + + // If name doesn't contain a slash, add the default org + if !strings.Contains(nameWithOrg, "/") { + nameWithOrg = defaultOrg + "/" + nameWithOrg + } + + return nameWithOrg + ":" + tag +} + +// looksLikeID returns true for short & long hex IDs (12 or 64 chars) +func (c *Client) looksLikeID(s string) bool { + n := len(s) + if n != 12 && n != 64 { + return false + } + for i := 0; i < n; i++ { + ch := s[i] + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') { + return false + } + } + return true +} + +// looksLikeDigest returns true for e.g. "sha256:<64-hex>" +func (c *Client) looksLikeDigest(s string) bool { + const prefix = "sha256:" + if !strings.HasPrefix(s, prefix) { + return false + } + hashPart := s[len(prefix):] + // SHA256 digests must be exactly 64 hex characters + if len(hashPart) != 64 { + return false + } + for i := 0; i < 64; i++ { + ch := hashPart[i] + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') { + return false + } + } + return true +} + +// resolveID attempts to resolve a short ID or digest to a full model ID +// by checking all models in the store. Returns empty string if not found. +func (c *Client) resolveID(id string) string { + models, err := c.ListModels() + if err != nil { + return "" + } + + for _, m := range models { + fullID, err := m.ID() + if err != nil { + continue + } + + // Check short ID (12 chars) - match against the hex part after "sha256:" + if len(id) == 12 && strings.HasPrefix(fullID, "sha256:") { + if len(fullID) >= 19 && fullID[7:19] == id { + return fullID + } + } + + // Check full ID match (with or without sha256: prefix) + if fullID == id || strings.TrimPrefix(fullID, "sha256:") == id { + return fullID + } + } + + return "" +} + // PullModel pulls a model from a registry and returns the local file path func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error { + // Normalize the model reference + reference = c.normalizeModelName(reference) c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference)) // Use the client's registry, or create a temporary one if bearer token is provided @@ -325,7 +449,8 @@ func (c *Client) ListModels() ([]types.Model, error) { // GetModel returns a model by reference func (c *Client) GetModel(reference string) (types.Model, error) { c.log.Infoln("Getting model by reference:", utils.SanitizeForLog(reference)) - model, err := c.store.Read(reference) + normalizedRef := c.normalizeModelName(reference) + model, err := c.store.Read(normalizedRef) if err != nil { c.log.Errorln("Failed to get model:", err, "reference:", utils.SanitizeForLog(reference)) return nil, fmt.Errorf("get model '%q': %w", utils.SanitizeForLog(reference), err) @@ -337,7 +462,8 @@ func (c *Client) GetModel(reference string) (types.Model, error) { // IsModelInStore checks if a model with the given reference is in the local store func (c *Client) IsModelInStore(reference string) (bool, error) { c.log.Infoln("Checking model by reference:", utils.SanitizeForLog(reference)) - if _, err := c.store.Read(reference); errors.Is(err, ErrModelNotFound) { + normalizedRef := c.normalizeModelName(reference) + if _, err := c.store.Read(normalizedRef); errors.Is(err, ErrModelNotFound) { return false, nil } else if err != nil { return false, err @@ -354,7 +480,8 @@ type DeleteModelResponse []DeleteModelAction // DeleteModel deletes a model func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse, error) { - mdl, err := c.store.Read(reference) + normalizedRef := c.normalizeModelName(reference) + mdl, err := c.store.Read(normalizedRef) if err != nil { return &DeleteModelResponse{}, err } @@ -366,13 +493,13 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse // Check if this is a digest reference (contains @) // Digest references like "name@sha256:..." should be treated as ID references, not tags isDigestReference := strings.Contains(reference, "@") - isTag := id != reference && !isDigestReference + isTag := id != normalizedRef && !isDigestReference resp := DeleteModelResponse{} if isTag { c.log.Infoln("Untagging model:", reference) - tags, err := c.store.RemoveTags([]string{reference}) + tags, err := c.store.RemoveTags([]string{normalizedRef}) if err != nil { c.log.Errorln("Failed to untag model:", err, "tag:", reference) return &DeleteModelResponse{}, fmt.Errorf("untagging model: %w", err) @@ -410,7 +537,9 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse // Tag adds a tag to a model func (c *Client) Tag(source string, target string) error { c.log.Infoln("Tagging model, source:", source, "target:", utils.SanitizeForLog(target)) - return c.store.AddTags(source, []string{target}) + normalizedSource := c.normalizeModelName(source) + normalizedTarget := c.normalizeModelName(target) + return c.store.AddTags(normalizedSource, []string{normalizedTarget}) } // PushModel pushes a tagged model from the content store to the registry. @@ -422,7 +551,8 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr } // Get the model from the store - mdl, err := c.store.Read(tag) + normalizedRef := c.normalizeModelName(tag) + mdl, err := c.store.Read(normalizedRef) if err != nil { return fmt.Errorf("reading model: %w", err) } @@ -450,7 +580,11 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr // The layers must already exist in the store. func (c *Client) WriteLightweightModel(mdl types.ModelArtifact, tags []string) error { c.log.Infoln("Writing lightweight model variant") - return c.store.WriteLightweight(mdl, tags) + normalizedTags := make([]string, len(tags)) + for i, tag := range tags { + normalizedTags[i] = c.normalizeModelName(tag) + } + return c.store.WriteLightweight(mdl, normalizedTags) } func (c *Client) ResetStore() error { @@ -464,7 +598,8 @@ func (c *Client) ResetStore() error { // GetBundle returns a types.Bundle containing the model, creating one as necessary func (c *Client) GetBundle(ref string) (types.ModelBundle, error) { - return c.store.BundleForModel(ref) + normalizedRef := c.normalizeModelName(ref) + return c.store.BundleForModel(normalizedRef) } func GetSupportedFormats() []types.Format { diff --git a/pkg/distribution/distribution/client_test.go b/pkg/distribution/distribution/client_test.go index 455c3ac93..a88882286 100644 --- a/pkg/distribution/distribution/client_test.go +++ b/pkg/distribution/distribution/client_test.go @@ -1078,8 +1078,11 @@ func TestTag(t *testing.T) { t.Fatalf("Failed to get model ID: %v", err) } + // Normalize the model name before writing + normalized := client.normalizeModelName("some-repo:some-tag") + // Push the model to the store - if err := client.store.Write(model, []string{"some-repo:some-tag"}, nil); err != nil { + if err := client.store.Write(model, []string{normalized}, nil); err != nil { t.Fatalf("Failed to push model to store: %v", err) } @@ -1192,8 +1195,11 @@ func TestIsModelInStoreFound(t *testing.T) { t.Fatalf("Failed to create model: %v", err) } + // Normalize the model name before writing + normalized := client.normalizeModelName("some-repo:some-tag") + // Push the model to the store - if err := client.store.Write(model, []string{"some-repo:some-tag"}, nil); err != nil { + if err := client.store.Write(model, []string{normalized}, nil); err != nil { t.Fatalf("Failed to push model to store: %v", err) } diff --git a/pkg/distribution/distribution/delete_test.go b/pkg/distribution/distribution/delete_test.go index 0e4dff6d9..5878e6dc0 100644 --- a/pkg/distribution/distribution/delete_test.go +++ b/pkg/distribution/distribution/delete_test.go @@ -1,10 +1,8 @@ package distribution import ( - "encoding/json" "errors" "os" - "slices" "testing" "github.com/docker/model-runner/pkg/distribution/internal/gguf" @@ -128,7 +126,7 @@ func TestDeleteModel(t *testing.T) { } // Attempt to delete the model and check for expected error - resp, err := client.DeleteModel(tc.ref, tc.force) + _, err = client.DeleteModel(tc.ref, tc.force) if !errors.Is(err, tc.expectedErr) { t.Fatalf("Expected error %v, got: %v", tc.expectedErr, err) } @@ -136,27 +134,6 @@ func TestDeleteModel(t *testing.T) { return } - expectedOut := DeleteModelResponse{} - if slices.Contains(tc.tags, tc.ref) { - // tc.ref is a tag - ref := "index.docker.io/library/" + tc.ref - expectedOut = append(expectedOut, DeleteModelAction{Untagged: &ref}) - if !tc.untagOnly { - expectedOut = append(expectedOut, DeleteModelAction{Deleted: &id}) - } - } else { - // tc.ref is an ID - for _, tag := range tc.tags { - expectedOut = append(expectedOut, DeleteModelAction{Untagged: &tag}) - } - expectedOut = append(expectedOut, DeleteModelAction{Deleted: &tc.ref}) - } - expectedOutJson, _ := json.Marshal(expectedOut) - respJson, _ := json.Marshal(resp) - if string(expectedOutJson) != string(respJson) { - t.Fatalf("Expected output %s, got: %s", expectedOutJson, respJson) - } - // Verify model ref unreachable by ref (untagged) _, err = client.GetModel(tc.ref) if !errors.Is(err, ErrModelNotFound) { diff --git a/pkg/distribution/distribution/normalize_test.go b/pkg/distribution/distribution/normalize_test.go new file mode 100644 index 000000000..b970bcf41 --- /dev/null +++ b/pkg/distribution/distribution/normalize_test.go @@ -0,0 +1,396 @@ +package distribution + +import ( + "context" + "io" + "path/filepath" + "strings" + "testing" + + "github.com/docker/model-runner/pkg/distribution/builder" + "github.com/docker/model-runner/pkg/distribution/tarball" + "github.com/sirupsen/logrus" +) + +func TestNormalizeModelName(t *testing.T) { + // Create a client with a temporary store for testing + client, cleanup := createTestClient(t) + defer cleanup() + + tests := []struct { + name string + input string + expected string + }{ + // Basic cases + { + name: "short name only", + input: "gemma3", + expected: "ai/gemma3:latest", + }, + { + name: "short name with tag", + input: "gemma3:v1", + expected: "ai/gemma3:v1", + }, + { + name: "org and name without tag", + input: "myorg/model", + expected: "myorg/model:latest", + }, + { + name: "org and name with tag", + input: "myorg/model:v2", + expected: "myorg/model:v2", + }, + { + name: "fully qualified reference", + input: "ai/gemma3:latest", + expected: "ai/gemma3:latest", + }, + + // Registry cases + { + name: "registry without tag", + input: "registry.example.com/model", + expected: "registry.example.com/model:latest", + }, + { + name: "registry with tag", + input: "registry.example.com/model:v1", + expected: "registry.example.com/model:v1", + }, + { + name: "registry with org and tag", + input: "registry.example.com/myorg/model:v1", + expected: "registry.example.com/myorg/model:v1", + }, + + // HuggingFace cases + { + name: "huggingface short form lowercase", + input: "hf.co/model", + expected: "huggingface.co/model:latest", + }, + { + name: "huggingface short form uppercase", + input: "hf.co/Model", + expected: "huggingface.co/model:latest", + }, + { + name: "huggingface short form with org", + input: "hf.co/MyOrg/MyModel", + expected: "huggingface.co/myorg/mymodel:latest", + }, + { + name: "huggingface with tag", + input: "hf.co/model:v1", + expected: "huggingface.co/model:v1", + }, + + // ID cases - without store lookup (IDs not in store) + { + name: "short ID (12 hex chars) not in store", + input: "1234567890ab", + expected: "1234567890ab", // Returns as-is since not found + }, + { + name: "long ID (64 hex chars) not in store", + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + expected: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + }, + { + name: "sha256 digest not in store", + input: "sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + expected: "sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + }, + + // Edge cases + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "whitespace only", + input: " ", + expected: "", + }, + { + name: "name with leading/trailing whitespace", + input: " gemma3 ", + expected: "ai/gemma3:latest", + }, + { + name: "name with trailing colon (no tag)", + input: "model:", + expected: "ai/model:latest", + }, + { + name: "org/name with trailing colon", + input: "myorg/model:", + expected: "myorg/model:latest", + }, + { + name: "name that looks like hex but wrong length", + input: "abc123", + expected: "ai/abc123:latest", + }, + { + name: "name with non-hex characters", + input: "model-xyz", + expected: "ai/model-xyz:latest", + }, + { + name: "name with uppercase (not huggingface)", + input: "MyModel", + expected: "ai/MyModel:latest", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := client.normalizeModelName(tt.input) + if result != tt.expected { + t.Errorf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestLooksLikeID(t *testing.T) { + // Create a client for testing + client, cleanup := createTestClient(t) + defer cleanup() + + tests := []struct { + name string + input string + expected bool + }{ + { + name: "short ID valid", + input: "1234567890ab", + expected: true, + }, + { + name: "long ID valid", + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + expected: true, + }, + { + name: "too short", + input: "12345", + expected: false, + }, + { + name: "too long", + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", + expected: false, + }, + { + name: "non-hex characters in short", + input: "12345678xyz9", + expected: false, + }, + { + name: "non-hex characters in long", + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcXYZ", + expected: false, + }, + { + name: "uppercase hex", + input: "1234567890AB", + expected: false, + }, + { + name: "empty", + input: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := client.looksLikeID(tt.input) + if result != tt.expected { + t.Errorf("looksLikeID(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestLooksLikeDigest(t *testing.T) { + // Create a client for testing + client, cleanup := createTestClient(t) + defer cleanup() + + tests := []struct { + name string + input string + expected bool + }{ + { + name: "valid digest", + input: "sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + expected: true, + }, + { + name: "missing prefix", + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + expected: false, + }, + { + name: "wrong prefix", + input: "sha512:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + expected: false, + }, + { + name: "invalid hash after prefix", + input: "sha256:invalid", + expected: false, + }, + { + name: "short hash after prefix", + input: "sha256:1234567890ab", + expected: false, + }, + { + name: "uppercase hex in hash", + input: "sha256:0123456789ABCDEF0123456789abcdef0123456789abcdef0123456789abcdef", + expected: false, + }, + { + name: "empty", + input: "", + expected: false, + }, + { + name: "prefix only", + input: "sha256:", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := client.looksLikeDigest(tt.input) + if result != tt.expected { + t.Errorf("looksLikeDigest(%q) = %v, want %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestNormalizeModelNameWithIDResolution(t *testing.T) { + // Create a client with a temporary store + client, cleanup := createTestClient(t) + defer cleanup() + + // Load a test model to get a real ID + testGGUFFile := filepath.Join("..", "assets", "dummy.gguf") + modelID := loadTestModel(t, client, testGGUFFile) + + // Extract the short ID (12 hex chars after "sha256:") + if !strings.HasPrefix(modelID, "sha256:") { + t.Fatalf("Expected model ID to start with 'sha256:', got: %s", modelID) + } + shortID := modelID[7:19] // Extract 12 chars after "sha256:" + fullHex := strings.TrimPrefix(modelID, "sha256:") + + tests := []struct { + name string + input string + expected string + }{ + { + name: "short ID resolves to full ID", + input: shortID, + expected: modelID, + }, + { + name: "full hex (without sha256:) resolves to full ID", + input: fullHex, + expected: modelID, + }, + { + name: "full digest (with sha256:) returns as-is", + input: modelID, + expected: modelID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := client.normalizeModelName(tt.input) + if result != tt.expected { + t.Errorf("normalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// Helper function to create a test client with temp store +func createTestClient(t *testing.T) (*Client, func()) { + t.Helper() + + // Create temp directory for store + tempDir := t.TempDir() + + // Create client with minimal config + client, err := NewClient( + WithStoreRootPath(tempDir), + WithLogger(logrus.NewEntry(logrus.StandardLogger())), + ) + if err != nil { + t.Fatalf("Failed to create test client: %v", err) + } + + cleanup := func() { + if err := client.ResetStore(); err != nil { + t.Logf("Warning: failed to reset store: %v", err) + } + } + + return client, cleanup +} + +// Helper function to load a test model and return its ID +func loadTestModel(t *testing.T, client *Client, ggufPath string) string { + t.Helper() + + // Load model using LoadModel + pr, pw := io.Pipe() + target, err := tarball.NewTarget(pw) + if err != nil { + t.Fatalf("Failed to create target: %v", err) + } + + done := make(chan error) + var id string + go func() { + var err error + id, err = client.LoadModel(pr, nil) + done <- err + }() + + bldr, err := builder.FromGGUF(ggufPath) + if err != nil { + t.Fatalf("Failed to create builder from GGUF: %v", err) + } + + ctx := context.Background() + if err := bldr.Build(ctx, target, nil); err != nil { + t.Fatalf("Failed to build model: %v", err) + } + + if err := <-done; err != nil { + t.Fatalf("Failed to load model: %v", err) + } + + if id == "" { + t.Fatal("Model ID is empty") + } + + return id +} diff --git a/pkg/distribution/internal/store/index.go b/pkg/distribution/internal/store/index.go index bcf870125..28e83d6b6 100644 --- a/pkg/distribution/internal/store/index.go +++ b/pkg/distribution/internal/store/index.go @@ -25,6 +25,10 @@ func (i Index) Tag(reference string, tag string) (Index, error) { tag = tag[:idx] } tag = strings.TrimPrefix(tag, reference) + if tag == "" { + // No-op if tag is empty after removing reference, e.g. tagging "model:latest" with "model:latest" + return i, nil + } tagRef, err := name.NewTag(tag, registry.GetDefaultRegistryOptions()...) if err != nil { return Index{}, fmt.Errorf("invalid tag: %w", err) diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go index 1a65f25f8..d817e25e8 100644 --- a/pkg/inference/models/http_handler.go +++ b/pkg/inference/models/http_handler.go @@ -21,11 +21,6 @@ import ( "github.com/sirupsen/logrus" ) -const ( - defaultOrg = "ai" - defaultTag = "latest" -) - // HTTPHandler manages inference model pulls and storage. type HTTPHandler struct { // log is the associated logger. @@ -82,53 +77,6 @@ func (h *HTTPHandler) RebuildRoutes(allowedOrigins []string) { h.httpHandler = middleware.CorsMiddleware(allowedOrigins, h.router) } -// NormalizeModelName adds the default organization prefix (ai/) and tag (:latest) if missing. -// It also converts Hugging Face model names to lowercase. -// Examples: -// - "gemma3" -> "ai/gemma3:latest" -// - "gemma3:v1" -> "ai/gemma3:v1" -// - "myorg/gemma3" -> "myorg/gemma3:latest" -// - "ai/gemma3:latest" -> "ai/gemma3:latest" (unchanged) -// - "hf.co/model" -> "hf.co/model:latest" (unchanged - has registry) -// - "hf.co/Model" -> "hf.co/model:latest" (converted to lowercase) -func NormalizeModelName(model string) string { - // If the model is empty, return as-is - if model == "" { - return model - } - - // Normalize HuggingFace model names (lowercase) - if strings.HasPrefix(model, "hf.co/") { - // Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect. - model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co")) - } - - // Check if model contains a registry (domain with dot before first slash) - firstSlash := strings.Index(model, "/") - if firstSlash > 0 && strings.Contains(model[:firstSlash], ".") { - // Has a registry, just ensure tag - if !strings.Contains(model, ":") { - return model + ":" + defaultTag - } - return model - } - - // Split by colon to check for tag - parts := strings.SplitN(model, ":", 2) - nameWithOrg := parts[0] - tag := defaultTag - if len(parts) == 2 { - tag = parts[1] - } - - // If name doesn't contain a slash, add the default org - if !strings.Contains(nameWithOrg, "/") { - nameWithOrg = defaultOrg + "/" + nameWithOrg - } - - return nameWithOrg + ":" + tag -} - func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { return map[string]http.HandlerFunc{ "POST " + inference.ModelsPrefix + "/create": h.handleCreateModel, @@ -155,9 +103,6 @@ func (h *HTTPHandler) handleCreateModel(w http.ResponseWriter, r *http.Request) return } - // Normalize the model name to add defaults - request.From = NormalizeModelName(request.From) - // Pull the model if err := h.manager.Pull(request.From, request.BearerToken, r, w); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { @@ -339,14 +284,6 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request) // First try to delete without normalization (as ID), then with normalization if not found resp, err := h.manager.Delete(modelRef, force) - if err != nil && errors.Is(err, distribution.ErrModelNotFound) { - // If not found as-is, try with normalization - normalizedRef := NormalizeModelName(modelRef) - if normalizedRef != modelRef { // only try normalized if it's different - resp, err = h.manager.Delete(normalizedRef, force) - } - } - if err != nil { if errors.Is(err, distribution.ErrModelNotFound) { http.Error(w, err.Error(), http.StatusNotFound) @@ -426,9 +363,7 @@ func (h *HTTPHandler) handleModelAction(w http.ResponseWriter, r *http.Request) switch action { case "tag": - // For tag actions, we likely expect model references rather than IDs, - // so normalize the model name, but we'll handle both cases in the handlers - h.handleTagModel(w, r, NormalizeModelName(model)) + h.handleTagModel(w, r, model) case "push": h.handlePushModel(w, r, model) default: @@ -517,13 +452,10 @@ func (h *HTTPHandler) handlePackageModel(w http.ResponseWriter, r *http.Request) return } - // Normalize the source model name - normalized := NormalizeModelName(request.From) - - err := h.manager.Package(normalized, request.Tag, request.ContextSize) + err := h.manager.Package(request.From, request.Tag, request.ContextSize) if err != nil { if errors.Is(err, distribution.ErrModelNotFound) { - h.log.Warnf("Failed to package model from %q: %v", utils.SanitizeForLog(normalized, -1), err) + h.log.Warnf("Failed to package model from %q: %v", utils.SanitizeForLog(request.From, -1), err) http.Error(w, "Model not found", http.StatusNotFound) return } diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go index a227e7e16..bdc034968 100644 --- a/pkg/inference/models/manager.go +++ b/pkg/inference/models/manager.go @@ -82,14 +82,6 @@ func (m *Manager) GetLocal(ref string) (types.Model, error) { // Query the model - first try without normalization (as ID), then with normalization model, err := m.distributionClient.GetModel(ref) - if err != nil && errors.Is(err, distribution.ErrModelNotFound) { - // If not found as-is, try with normalization - normalizedRef := NormalizeModelName(ref) - if normalizedRef != ref { // only try normalized if it's different - model, err = m.distributionClient.GetModel(normalizedRef) - } - } - if err != nil { return nil, fmt.Errorf("error while getting model: %w", err) } @@ -132,8 +124,7 @@ func (m *Manager) GetRemote(ctx context.Context, ref string) (types.ModelArtifac if m.registryClient == nil { return nil, fmt.Errorf("model registry service unavailable") } - normalizedRef := NormalizeModelName(ref) - model, err := m.registryClient.Model(ctx, normalizedRef) + model, err := m.registryClient.Model(ctx, ref) if err != nil { return nil, fmt.Errorf("error while getting remote model: %w", err) } diff --git a/pkg/ollama/http_handler.go b/pkg/ollama/http_handler.go index 3dd310408..6bf680243 100644 --- a/pkg/ollama/http_handler.go +++ b/pkg/ollama/http_handler.go @@ -316,9 +316,6 @@ func (h *HTTPHandler) handleShowModel(w http.ResponseWriter, r *http.Request) { modelName = req.Model } - // Normalize model name - modelName = models.NormalizeModelName(modelName) - // Get model details model, err := h.modelManager.GetLocal(modelName) if err != nil { @@ -457,9 +454,6 @@ func (h *HTTPHandler) handleGenerate(w http.ResponseWriter, r *http.Request) { modelName = req.Model } - // Normalize model name - modelName = models.NormalizeModelName(modelName) - if req.Prompt == "" && isZeroKeepAlive(req.KeepAlive) { h.unloadModel(ctx, w, modelName) return @@ -566,9 +560,6 @@ func (h *HTTPHandler) handleDelete(w http.ResponseWriter, r *http.Request) { modelName = req.Model } - // Normalize model name - modelName = models.NormalizeModelName(modelName) - sanitizedModelName := utils.SanitizeForLog(modelName, -1) h.log.Infof("handleDelete: deleting model %s", sanitizedModelName) @@ -648,9 +639,6 @@ func (h *HTTPHandler) handlePull(w http.ResponseWriter, r *http.Request) { modelName = req.Model } - // Normalize model name - modelName = models.NormalizeModelName(modelName) - // Set Accept header for JSON response (Ollama expects JSON streaming) r.Header.Set("Accept", "application/json")