diff --git a/.gitignore b/.gitignore
index 5b01aba1a..cfc1e2894 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,4 @@ vendor/
 # model-distribution
 pkg/distribution/bin/
 /parallelget
+/cli
diff --git a/cmd/cli/commands/configure.go b/cmd/cli/commands/configure.go
index 848be0587..a7dc3ebfa 100644
--- a/cmd/cli/commands/configure.go
+++ b/cmd/cli/commands/configure.go
@@ -4,6 +4,7 @@ import (
     "fmt"
 
     "github.com/docker/model-runner/cmd/cli/commands/completion"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/docker/model-runner/pkg/inference/scheduling"
     "github.com/spf13/cobra"
 )
@@ -33,7 +34,7 @@ func newConfigureCmd() *cobra.Command {
                     argsBeforeDash)
             }
         }
-        opts.Model = args[0]
+        opts.Model = models.NormalizeModelName(args[0])
         opts.RuntimeFlags = args[1:]
         return nil
     },
diff --git a/cmd/cli/commands/inspect.go b/cmd/cli/commands/inspect.go
index 1a40139f3..3ec87f3a7 100644
--- a/cmd/cli/commands/inspect.go
+++ b/cmd/cli/commands/inspect.go
@@ -6,6 +6,7 @@ import (
     "github.com/docker/model-runner/cmd/cli/commands/completion"
     "github.com/docker/model-runner/cmd/cli/commands/formatter"
     "github.com/docker/model-runner/cmd/cli/desktop"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/spf13/cobra"
 )
 
@@ -47,7 +48,8 @@ func newInspectCmd() *cobra.Command {
 }
 
 func inspectModel(args []string, openai bool, remote bool, desktopClient *desktop.Client) (string, error) {
-    modelName := args[0]
+    // Normalize model name to add default org and tag if missing
+    modelName := models.NormalizeModelName(args[0])
     if openai {
         model, err := desktopClient.InspectOpenAI(modelName)
         if err != nil {
diff --git a/cmd/cli/commands/list.go b/cmd/cli/commands/list.go
index cf984bfbc..b4dfa9e44 100644
--- a/cmd/cli/commands/list.go
+++ b/cmd/cli/commands/list.go
@@ -4,7 +4,6 @@ import (
     "bytes"
     "fmt"
     "os"
-    "slices"
     "strings"
     "time"
 
@@ -89,12 +88,20 @@ func listModels(openai bool, backend string, desktopClient *desktop.Client, quie
     }
     if modelFilter != "" {
+        // Normalize the filter to match stored model names
+        normalizedFilter := dmrm.NormalizeModelName(modelFilter)
         var filteredModels []dmrm.Model
         for _, m := range models {
             hasMatchingTag := false
             for _, tag := range m.Tags {
+                if tag == normalizedFilter {
+                    hasMatchingTag = true
+                    break
+                }
+                // Also check without the tag part
                 modelName, _, _ := strings.Cut(tag, ":")
-                if slices.Contains([]string{modelName, tag + ":latest", tag}, modelFilter) {
+                filterName, _, _ := strings.Cut(normalizedFilter, ":")
+                if modelName == filterName {
                     hasMatchingTag = true
                     break
                 }
@@ -165,8 +172,10 @@ func appendRow(table *tablewriter.Table, tag string, model dmrm.Model) {
         fmt.Fprintf(os.Stderr, "invalid model ID for model: %v\n", model)
         return
     }
+    // Strip default "ai/" prefix and ":latest" tag for display
+    displayTag := stripDefaultsFromModelName(tag)
     table.Append([]string{
-        tag,
+        displayTag,
         model.Config.Parameters,
         model.Config.Quantization,
         model.Config.Architecture,
diff --git a/cmd/cli/commands/ps.go b/cmd/cli/commands/ps.go
index 775b73761..82480ad4a 100644
--- a/cmd/cli/commands/ps.go
+++ b/cmd/cli/commands/ps.go
@@ -54,6 +54,9 @@ func psTable(ps []desktop.BackendStatus) string {
         modelName := status.ModelName
         if strings.HasPrefix(modelName, "sha256:") {
             modelName = modelName[7:19]
+        } else {
+            // Strip default "ai/" prefix and ":latest" tag for display
+            modelName = stripDefaultsFromModelName(modelName)
         }
         table.Append([]string{
             modelName,
diff --git a/cmd/cli/commands/pull.go b/cmd/cli/commands/pull.go
index c311a747c..38f3c7b5b 100644
--- a/cmd/cli/commands/pull.go
+++ b/cmd/cli/commands/pull.go
@@ -6,6 +6,7 @@ import (
 
     "github.com/docker/model-runner/cmd/cli/commands/completion"
     "github.com/docker/model-runner/cmd/cli/desktop"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/mattn/go-isatty"
     "github.com/spf13/cobra"
 )
@@ -41,6 +42,8 @@ func newPullCmd() *cobra.Command {
 }
 
 func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
+    // Normalize model name to add default org and tag if missing
+    model = models.NormalizeModelName(model)
     var progress func(string)
     if isatty.IsTerminal(os.Stdout.Fd()) {
         progress = TUIProgress
diff --git a/cmd/cli/commands/push.go b/cmd/cli/commands/push.go
index 4aa83c2cc..72f614c67 100644
--- a/cmd/cli/commands/push.go
+++ b/cmd/cli/commands/push.go
@@ -5,6 +5,7 @@ import (
 
     "github.com/docker/model-runner/cmd/cli/commands/completion"
     "github.com/docker/model-runner/cmd/cli/desktop"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/spf13/cobra"
 )
 
@@ -34,6 +35,8 @@ func newPushCmd() *cobra.Command {
 }
 
 func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
+    // Normalize model name to add default org and tag if missing
+    model = models.NormalizeModelName(model)
     response, progressShown, err := desktopClient.Push(model, TUIProgress)
 
     // Add a newline before any output (success or error) if progress was shown.
diff --git a/cmd/cli/commands/rm.go b/cmd/cli/commands/rm.go
index a02b1be0b..fb9f39490 100644
--- a/cmd/cli/commands/rm.go
+++ b/cmd/cli/commands/rm.go
@@ -4,6 +4,7 @@ import (
     "fmt"
 
     "github.com/docker/model-runner/cmd/cli/commands/completion"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/spf13/cobra"
 )
 
@@ -27,7 +28,12 @@ func newRemoveCmd() *cobra.Command {
             if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil {
                 return fmt.Errorf("unable to initialize standalone model runner: %w", err)
             }
-            response, err := desktopClient.Remove(args, force)
+            // Normalize model names to add default org and tag if missing
+            normalizedArgs := make([]string, len(args))
+            for i, arg := range args {
+                normalizedArgs[i] = models.NormalizeModelName(arg)
+            }
+            response, err := desktopClient.Remove(normalizedArgs, force)
             if response != "" {
                 cmd.Print(response)
             }
diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go
index 379e609c6..2f2ff9822 100644
--- a/cmd/cli/commands/run.go
+++ b/cmd/cli/commands/run.go
@@ -15,6 +15,7 @@ import (
     "github.com/docker/model-runner/cmd/cli/commands/completion"
     "github.com/docker/model-runner/cmd/cli/desktop"
    "github.com/docker/model-runner/cmd/cli/readline"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/fatih/color"
     "github.com/spf13/cobra"
     "golang.org/x/term"
@@ -561,7 +562,8 @@ func newRunCmd() *cobra.Command {
             return err
         }
 
-        model := args[0]
+        // Normalize model name to add default org and tag if missing
+        model := models.NormalizeModelName(args[0])
         prompt := ""
         argsLen := len(args)
         if argsLen > 1 {
diff --git a/cmd/cli/commands/tag.go b/cmd/cli/commands/tag.go
index 4cf08663c..f7197b363 100644
--- a/cmd/cli/commands/tag.go
+++ b/cmd/cli/commands/tag.go
@@ -6,6 +6,7 @@ import (
 
     "github.com/docker/model-runner/cmd/cli/commands/completion"
     "github.com/docker/model-runner/cmd/cli/desktop"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/google/go-containerregistry/pkg/name"
     "github.com/spf13/cobra"
 )
@@ -36,6 +37,8 @@ func newTagCmd() *cobra.Command {
 }
 
 func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error {
+    // Normalize source model name to add default org and tag if missing
+    source = models.NormalizeModelName(source)
     // Ensure tag is valid
     tag, err := name.NewTag(target)
     if err != nil {
diff --git a/cmd/cli/commands/unload.go b/cmd/cli/commands/unload.go
index 11fd85fda..9e32c335a 100644
--- a/cmd/cli/commands/unload.go
+++ b/cmd/cli/commands/unload.go
@@ -5,6 +5,7 @@ import (
 
     "github.com/docker/model-runner/cmd/cli/commands/completion"
     "github.com/docker/model-runner/cmd/cli/desktop"
+    "github.com/docker/model-runner/pkg/inference/models"
     "github.com/spf13/cobra"
 )
 
@@ -16,8 +17,13 @@ func newUnloadCmd() *cobra.Command {
     c := &cobra.Command{
         Use: "unload " + cmdArgs,
         Short: "Unload running models",
-        RunE: func(cmd *cobra.Command, models []string) error {
-            unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: models})
+        RunE: func(cmd *cobra.Command, modelArgs []string) error {
+            // Normalize model names
+            normalizedModels := make([]string, len(modelArgs))
+            for i, model := range modelArgs {
+                normalizedModels[i] = models.NormalizeModelName(model)
+            }
+            unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{All: all, Backend: backend, Models: normalizedModels})
             if err != nil {
                 err = handleClientError(err, "Failed to unload models")
                 return handleNotRunningError(err)
diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go
index 46c4b5063..c8b4abbe3 100644
--- a/cmd/cli/commands/utils.go
+++ b/cmd/cli/commands/utils.go
@@ -10,6 +10,11 @@ import (
     "github.com/pkg/errors"
 )
 
+const (
+    defaultOrg = "ai"
+    defaultTag = "latest"
+)
+
 const (
     enableViaCLI = "Enable Docker Model Runner via the CLI → docker desktop enable model-runner"
     enableViaGUI = "Enable Docker Model Runner via the GUI → Go to Settings->AI->Enable Docker Model Runner"
@@ -32,3 +37,25 @@ func handleNotRunningError(err error) error {
     }
     return err
 }
+
+// stripDefaultsFromModelName removes the default "ai/" prefix and ":latest" tag for display.
+// Examples:
+// - "ai/gemma3:latest" -> "gemma3"
+// - "ai/gemma3:v1" -> "gemma3:v1"
+// - "myorg/gemma3:latest" -> "myorg/gemma3"
+// - "gemma3:latest" -> "gemma3"
+// - "hf.co/bartowski/model:latest" -> "hf.co/bartowski/model"
+func stripDefaultsFromModelName(model string) string {
+    // Strip the default "ai/" org prefix if present.
+    if strings.HasPrefix(model, defaultOrg+"/") {
+        model = strings.TrimPrefix(model, defaultOrg+"/")
+    }
+
+    // Strip the default ":latest" tag if present.
+    if strings.HasSuffix(model, ":"+defaultTag) {
+        model = strings.TrimSuffix(model, ":"+defaultTag)
+    }
+
+    // Anything else (custom org, custom tag, registry-qualified name) is left untouched.
+    return model
+}
diff --git a/cmd/cli/commands/utils_test.go b/cmd/cli/commands/utils_test.go
new file mode 100644
index 000000000..ad4a8e3c3
--- /dev/null
+++ b/cmd/cli/commands/utils_test.go
@@ -0,0 +1,148 @@
+package commands
+
+import (
+    "testing"
+
+    "github.com/docker/model-runner/pkg/inference/models"
+)
+
+func TestNormalizeModelName(t *testing.T) {
+    tests := []struct {
+        name string
+        input string
+        expected string
+    }{
+        {
+            name: "simple model name",
+            input: "gemma3",
+            expected: "ai/gemma3:latest",
+        },
+        {
+            name: "model name with tag",
+            input: "gemma3:v1",
+            expected: "ai/gemma3:v1",
+        },
+        {
+            name: "model name with org",
+            input: "myorg/gemma3",
+            expected: "myorg/gemma3:latest",
+        },
+        {
+            name: "model name with org and tag",
+            input: "myorg/gemma3:v1",
+            expected: "myorg/gemma3:v1",
+        },
+        {
+            name: "fully qualified model name",
+            input: "ai/gemma3:latest",
+            expected: "ai/gemma3:latest",
+        },
+        {
+            name: "huggingface model",
+            input: "hf.co/bartowski/model",
+            expected: "hf.co/bartowski/model:latest",
+        },
+        {
+            name: "huggingface model with tag",
+            input: "hf.co/bartowski/model:Q4_K_S",
+            expected: "hf.co/bartowski/model:q4_k_s",
+        },
+        {
+            name: "registry with model",
+            input: "docker.io/library/model",
+            expected: "docker.io/library/model:latest",
+        },
+        {
+            name: "registry with model and tag",
+            input: "docker.io/library/model:v1",
+            expected: "docker.io/library/model:v1",
+        },
+        {
+            name: "empty string",
+            input: "",
+            expected: "",
+        },
+        {
+            name: "ai prefix already present",
+            input: "ai/gemma3",
+            expected: "ai/gemma3:latest",
+        },
+        {
+            name: "model name with latest tag already",
+            input: "gemma3:latest",
+            expected: "ai/gemma3:latest",
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := models.NormalizeModelName(tt.input)
+            if result != tt.expected {
+                t.Errorf("NormalizeModelName(%q) = %q, want %q", tt.input, result, tt.expected)
+            }
+        })
+    }
+}
+
+func TestStripDefaultsFromModelName(t *testing.T) {
+    tests := []struct {
+        name string
+        input string
+        expected string
+    }{
+        {
+            name: "ai prefix and latest tag",
+            input: "ai/gemma3:latest",
+            expected: "gemma3",
+        },
+        {
+            name: "ai prefix with custom tag",
+            input: "ai/gemma3:v1",
+            expected: "gemma3:v1",
+        },
+        {
+            name: "custom org with latest tag",
+            input: "myorg/gemma3:latest",
+            expected: "myorg/gemma3",
+        },
+        {
+            name: "simple model name with latest",
+            input: "gemma3:latest",
+            expected: "gemma3",
+        },
+        {
+            name: "simple model name without tag",
+            input: "gemma3",
+            expected: "gemma3",
+        },
+        {
+            name: "ai prefix without tag",
+            input: "ai/gemma3",
+            expected: "gemma3",
+        },
+        {
+            name: "huggingface model with latest",
+            input: "hf.co/bartowski/model:latest",
+            expected: "hf.co/bartowski/model",
+        },
+        {
+            name: "huggingface model with custom tag",
+            input: "hf.co/bartowski/model:Q4_K_S",
+            expected: "hf.co/bartowski/model:Q4_K_S",
+        },
+        {
+            name: "empty string",
+            input: "",
+            expected: "",
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            result := stripDefaultsFromModelName(tt.input)
+            if result != tt.expected {
+                t.Errorf("stripDefaultsFromModelName(%q) = %q, want %q", tt.input, result, tt.expected)
+            }
+        })
+    }
+}
diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go
index 4b745d6c0..01b7b080f 100644
--- a/cmd/cli/desktop/desktop.go
+++ b/cmd/cli/desktop/desktop.go
@@ -58,14 +58,6 @@ type Status struct {
     Error error `json:"error"`
 }
 
-// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase
-func normalizeHuggingFaceModelName(model string) string {
-    if strings.HasPrefix(model, "hf.co/") {
-        return strings.ToLower(model)
-    }
-    return model
-}
-
 func (c *Client) Status() Status {
     // TODO: Query "/".
     resp, err := c.doRequest(http.MethodGet, inference.ModelsPrefix, nil)
@@ -108,7 +100,7 @@ func (c *Client) Status() Status {
 }
 
 func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
-    model = normalizeHuggingFaceModelName(model)
+    model = dmrm.NormalizeModelName(model)
     jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
     if err != nil {
         return "", false, fmt.Errorf("error marshaling request: %w", err)
@@ -176,7 +168,7 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func
 }
 
 func (c *Client) Push(model string, progress func(string)) (string, bool, error) {
-    model = normalizeHuggingFaceModelName(model)
+    model = dmrm.NormalizeModelName(model)
     pushPath := inference.ModelsPrefix + "/" + model + "/push"
     resp, err := c.doRequest(
         http.MethodPost,
@@ -271,7 +263,7 @@ func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error
 }
 
 func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) {
-    model = normalizeHuggingFaceModelName(model)
+    model = dmrm.NormalizeModelName(model)
     if model != "" {
         if !strings.Contains(strings.Trim(model, "/"), "/") {
             // Do an extra API call to check if the model parameter isn't a model ID.
@@ -295,7 +287,7 @@ func (c *Client) Inspect(model string, remote bool) (dmrm.Model, error) {
 }
 
 func (c *Client) InspectOpenAI(model string) (dmrm.OpenAIModel, error) {
-    model = normalizeHuggingFaceModelName(model)
+    model = dmrm.NormalizeModelName(model)
     modelsRoute := inference.InferencePrefix + "/v1/models"
     if !strings.Contains(strings.Trim(model, "/"), "/") {
         // Do an extra API call to check if the model parameter isn't a model ID.
@@ -371,7 +363,7 @@ func (c *Client) Chat(backend, model, prompt, apiKey string, outputFunc func(str
 
 // ChatWithContext performs a chat request with context support for cancellation and streams the response content with selective markdown rendering.
 func (c *Client) ChatWithContext(ctx context.Context, backend, model, prompt, apiKey string, outputFunc func(string), shouldUseMarkdown bool) error {
-    model = normalizeHuggingFaceModelName(model)
+    model = dmrm.NormalizeModelName(model)
     if !strings.Contains(strings.Trim(model, "/"), "/") {
         // Do an extra API call to check if the model parameter isn't a model ID.
         if expanded, err := c.fullModelID(model); err == nil {
@@ -521,10 +513,10 @@ func (c *Client) ChatWithContext(ctx context.Context, backend, model, prompt, ap
     return nil
 }
 
-func (c *Client) Remove(models []string, force bool) (string, error) {
+func (c *Client) Remove(modelArgs []string, force bool) (string, error) {
     modelRemoved := ""
-    for _, model := range models {
-        model = normalizeHuggingFaceModelName(model)
+    for _, model := range modelArgs {
+        model = dmrm.NormalizeModelName(model)
         // Check if not a model ID passed as parameter.
         if !strings.Contains(model, "/") {
             if expanded, err := c.fullModelID(model); err == nil {
@@ -808,7 +800,7 @@ func (c *Client) handleQueryError(err error, path string) error {
 }
 
 func (c *Client) Tag(source, targetRepo, targetTag string) error {
-    source = normalizeHuggingFaceModelName(source)
+    source = dmrm.NormalizeModelName(source)
     // Check if the source is a model ID, and expand it if necessary
     if !strings.Contains(strings.Trim(source, "/"), "/") {
         // Do an extra API call to check if the model parameter might be a model ID
diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go
index e5ab6a2b0..32d7907ce 100644
--- a/cmd/cli/desktop/desktop_test.go
+++ b/cmd/cli/desktop/desktop_test.go
@@ -20,7 +20,7 @@ func TestPullHuggingFaceModel(t *testing.T) {
 
     // Test case for pulling a Hugging Face model with mixed case
     modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
 
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
@@ -46,7 +46,7 @@ func TestChatHuggingFaceModel(t *testing.T) {
 
     // Test case for chatting with a Hugging Face model with mixed case
     modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
     prompt := "Hello"
 
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
@@ -73,7 +73,7 @@ func TestInspectHuggingFaceModel(t *testing.T) {
 
     // Test case for inspecting a Hugging Face model with mixed case
     modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
 
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
@@ -108,6 +108,7 @@ func TestNonHuggingFaceModel(t *testing.T) {
 
     // Test case for a non-Hugging Face model (should not be converted to lowercase)
     modelName := "docker.io/library/llama2"
+    expectedWithTag := "docker.io/library/llama2:latest"
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
     client := New(mockContext)
@@ -116,7 +117,7 @@
         var reqBody models.ModelCreateRequest
         err := json.NewDecoder(req.Body).Decode(&reqBody)
         require.NoError(t, err)
-        assert.Equal(t, modelName, reqBody.From)
+        assert.Equal(t, expectedWithTag, reqBody.From)
     }).Return(&http.Response{
         StatusCode: http.StatusOK,
         Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
@@ -132,7 +133,7 @@ func TestPushHuggingFaceModel(t *testing.T) {
 
     // Test case for pushing a Hugging Face model with mixed case
     modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
 
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
@@ -155,7 +156,7 @@ func TestRemoveHuggingFaceModel(t *testing.T) {
 
     // Test case for removing a Hugging Face model with mixed case
     modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
 
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
@@ -178,7 +179,7 @@ func TestTagHuggingFaceModel(t *testing.T) {
     // Test case for tagging a Hugging Face model with mixed case
     sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
     targetRepo := "myrepo"
     targetTag := "latest"
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
@@ -202,7 +203,7 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) {
 
     // Test case for inspecting a Hugging Face model with mixed case
     modelName := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
-    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
+    expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
 
     mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
     mockContext := NewContextForMock(mockClient)
diff --git a/pkg/inference/models/manager.go b/pkg/inference/models/manager.go
index 7c3a265a8..ed52760a3 100644
--- a/pkg/inference/models/manager.go
+++ b/pkg/inference/models/manager.go
@@ -28,6 +28,8 @@ const (
     // maximumConcurrentModelPulls is the maximum number of concurrent model
     // pulls that a model manager will allow.
     maximumConcurrentModelPulls = 2
+    defaultOrg = "ai"
+    defaultTag = "latest"
 )
 
 // Manager manages inference model pulls and storage.
@@ -121,6 +123,52 @@ func (m *Manager) RebuildRoutes(allowedOrigins []string) {
     m.httpHandler = middleware.CorsMiddleware(allowedOrigins, m.router)
 }
 
+// NormalizeModelName adds the default organization prefix (ai/) and tag (:latest) if missing.
+// It also converts Hugging Face model names to lowercase.
+// Examples:
+// - "gemma3" -> "ai/gemma3:latest"
+// - "gemma3:v1" -> "ai/gemma3:v1"
+// - "myorg/gemma3" -> "myorg/gemma3:latest"
+// - "ai/gemma3:latest" -> "ai/gemma3:latest" (unchanged)
+// - "hf.co/model" -> "hf.co/model:latest" (no default org added - has registry)
+// - "hf.co/Model" -> "hf.co/model:latest" (converted to lowercase)
+func NormalizeModelName(model string) string {
+    // If the model is empty, return as-is
+    if model == "" {
+        return model
+    }
+
+    // Normalize HuggingFace model names (lowercase)
+    if strings.HasPrefix(model, "hf.co/") {
+        model = strings.ToLower(model)
+    }
+
+    // Check if model contains a registry (domain with dot before first slash)
+    firstSlash := strings.Index(model, "/")
+    if firstSlash > 0 && strings.Contains(model[:firstSlash], ".") {
+        // Has a registry, just ensure tag
+        if !strings.Contains(model, ":") {
+            return model + ":" + defaultTag
+        }
+        return model
+    }
+
+    // Split by colon to check for tag
+    parts := strings.SplitN(model, ":", 2)
+    nameWithOrg := parts[0]
+    tag := defaultTag
+    if len(parts) == 2 {
+        tag = parts[1]
+    }
+
+    // If name doesn't contain a slash, add the default org
+    if !strings.Contains(nameWithOrg, "/") {
+        nameWithOrg = defaultOrg + "/" + nameWithOrg
+    }
+
+    return nameWithOrg + ":" + tag
+}
+
 func (m *Manager) routeHandlers() map[string]http.HandlerFunc {
     return map[string]http.HandlerFunc{
         "POST " + inference.ModelsPrefix + "/create": m.handleCreateModel,
@@ -150,6 +198,9 @@ func (m *Manager) handleCreateModel(w http.ResponseWriter, r *http.Request) {
         return
     }
 
+    // Normalize the model name to add defaults
+    request.From = NormalizeModelName(request.From)
+
     // Pull the model. In the future, we may support additional operations here
     // besides pulling (such as model building).
     if memory.RuntimeMemoryCheckEnabled() && !request.IgnoreRuntimeMemoryCheck {
@@ -243,6 +294,9 @@ func (m *Manager) handleGetModels(w http.ResponseWriter, r *http.Request) {
 
 // handleGetModel handles GET /models/{name} requests.
 func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
+    // Normalize model name
+    modelName := NormalizeModelName(r.PathValue("name"))
+
     // Parse remote query parameter
     remote := false
     if r.URL.Query().Has("remote") {
@@ -262,9 +316,9 @@ func (m *Manager) handleGetModel(w http.ResponseWriter, r *http.Request) {
 
     var err error
     if remote {
-        apiModel, err = getRemoteModel(r.Context(), m, r.PathValue("name"))
+        apiModel, err = getRemoteModel(r.Context(), m, modelName)
     } else {
-        apiModel, err = getLocalModel(m, r.PathValue("name"))
+        apiModel, err = getLocalModel(m, modelName)
     }
 
     if err != nil {
@@ -373,6 +427,9 @@ func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
     // the runner process exits (though this won't work for Windows, where we
     // might need some separate cleanup process).
 
+    // Normalize model name
+    modelName := NormalizeModelName(r.PathValue("name"))
+
     var force bool
     if r.URL.Query().Has("force") {
         if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
@@ -382,7 +439,7 @@ func (m *Manager) handleDeleteModel(w http.ResponseWriter, r *http.Request) {
         }
     }
 
-    resp, err := m.distributionClient.DeleteModel(r.PathValue("name"), force)
+    resp, err := m.distributionClient.DeleteModel(modelName, force)
     if err != nil {
         if errors.Is(err, distribution.ErrModelNotFound) {
             http.Error(w, err.Error(), http.StatusNotFound)
@@ -439,8 +496,11 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
         return
     }
 
+    // Normalize model name
+    modelName := NormalizeModelName(r.PathValue("name"))
+
     // Query the model.
-    model, err := m.GetModel(r.PathValue("name"))
+    model, err := m.GetModel(modelName)
     if err != nil {
         if errors.Is(err, distribution.ErrModelNotFound) {
             http.Error(w, err.Error(), http.StatusNotFound)
@@ -469,6 +529,8 @@ func (m *Manager) handleOpenAIGetModel(w http.ResponseWriter, r *http.Request) {
 func (m *Manager) handleModelAction(w http.ResponseWriter, r *http.Request) {
     model, action := path.Split(r.PathValue("nameAndAction"))
     model = strings.TrimRight(model, "/")
+    // Normalize model name
+    model = NormalizeModelName(model)
     switch action {
     case "tag":
         m.handleTagModel(w, r, model)
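
For reference, a minimal sketch (not part of the diff) of how the new models.NormalizeModelName helper behaves from a caller's point of view. The import path and the expected outputs come from the doc comment and the table-driven tests added above; the surrounding program is illustrative only.

    package main

    import (
        "fmt"

        "github.com/docker/model-runner/pkg/inference/models"
    )

    func main() {
        // Bare names get the default "ai/" org and ":latest" tag.
        fmt.Println(models.NormalizeModelName("gemma3")) // ai/gemma3:latest
        // An explicit tag is kept; only the org is defaulted.
        fmt.Println(models.NormalizeModelName("gemma3:v1")) // ai/gemma3:v1
        // Registry-qualified names never receive the "ai/" org, only a default tag.
        fmt.Println(models.NormalizeModelName("docker.io/library/model")) // docker.io/library/model:latest
        // Hugging Face references are additionally lowercased.
        fmt.Println(models.NormalizeModelName("hf.co/Bartowski/Model")) // hf.co/bartowski/model:latest
    }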
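
Similarly, a sketch of the display round trip introduced by this change: names are normalized before they reach the API and the defaults are stripped again before they are printed by list/ps. stripDefaultsFromModelName is unexported, so the snippet is written as it would live inside the commands package; the displayName helper is hypothetical and only illustrates how the two helpers compose.

    package commands

    import "github.com/docker/model-runner/pkg/inference/models"

    // displayName is a hypothetical helper (not part of the diff) showing how
    // normalization and display stripping compose.
    func displayName(userInput string) string {
        stored := models.NormalizeModelName(userInput) // "gemma3" -> "ai/gemma3:latest"
        return stripDefaultsFromModelName(stored)      // "ai/gemma3:latest" -> "gemma3"
    }

    // Round trips implied by the tests above:
    //   "gemma3"       -> "ai/gemma3:latest"    -> "gemma3"
    //   "gemma3:v1"    -> "ai/gemma3:v1"        -> "gemma3:v1"
    //   "myorg/gemma3" -> "myorg/gemma3:latest" -> "myorg/gemma3"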