Skip to content

Commit ed99eef

Browse files
authored
Merge pull request docker#262 from docker/fix-local-model-error
Fix model namespace normalization for package and tag commands
2 parents e9302fa + 82db531 commit ed99eef

File tree

5 files changed

+65
-3
lines changed

5 files changed

+65
-3
lines changed

cmd/cli/commands/package.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/docker/model-runner/pkg/distribution/registry"
1616
"github.com/docker/model-runner/pkg/distribution/tarball"
1717
"github.com/docker/model-runner/pkg/distribution/types"
18+
"github.com/docker/model-runner/pkg/inference/models"
1819
"github.com/google/go-containerregistry/pkg/name"
1920
"github.com/spf13/cobra"
2021

@@ -313,7 +314,9 @@ func newModelRunnerTarget(client *desktop.Client, tag string) (*modelRunnerTarge
313314
}
314315
if tag != "" {
315316
var err error
316-
target.tag, err = name.NewTag(tag)
317+
// Normalize the tag to add default namespace (ai/) and tag (:latest) if missing
318+
normalizedTag := models.NormalizeModelName(tag)
319+
target.tag, err = name.NewTag(normalizedTag)
317320
if err != nil {
318321
return nil, fmt.Errorf("invalid tag: %w", err)
319322
}

cmd/cli/commands/tag.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func newTagCmd() *cobra.Command {
3939
func tagModel(cmd *cobra.Command, desktopClient *desktop.Client, source, target string) error {
4040
// Normalize source model name to add default org and tag if missing
4141
source = models.NormalizeModelName(source)
42+
// Normalize target model name to add default org and tag if missing
43+
target = models.NormalizeModelName(target)
4244
// Ensure tag is valid
4345
tag, err := name.NewTag(target)
4446
if err != nil {

cmd/cli/commands/utils_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,51 @@ func TestStripDefaultsFromModelName(t *testing.T) {
146146
})
147147
}
148148
}
149+
150+
// TestNormalizeModelNameConsistency verifies that locally packaged models
151+
// (without namespace) get normalized the same way as other operations.
152+
// This test documents the fix for the bug where `docker model package my-model`
153+
// would create a model that couldn't be run with `docker model run my-model`.
154+
func TestNormalizeModelNameConsistency(t *testing.T) {
155+
tests := []struct {
156+
name string
157+
userProvidedName string
158+
expectedNormalizedName string
159+
description string
160+
}{
161+
{
162+
name: "locally packaged model without namespace",
163+
userProvidedName: "my-model",
164+
expectedNormalizedName: "ai/my-model:latest",
165+
description: "When a user packages a local model as 'my-model', it should be normalized to 'ai/my-model:latest'",
166+
},
167+
{
168+
name: "locally packaged model without namespace but with tag",
169+
userProvidedName: "my-model:v1.0",
170+
expectedNormalizedName: "ai/my-model:v1.0",
171+
description: "When a user packages a local model as 'my-model:v1.0', it should be normalized to 'ai/my-model:v1.0'",
172+
},
173+
{
174+
name: "model with explicit namespace",
175+
userProvidedName: "myorg/my-model",
176+
expectedNormalizedName: "myorg/my-model:latest",
177+
description: "When a user packages a model with explicit org 'myorg/my-model', it should keep the org",
178+
},
179+
{
180+
name: "model with ai namespace explicitly set",
181+
userProvidedName: "ai/my-model",
182+
expectedNormalizedName: "ai/my-model:latest",
183+
description: "When a user explicitly sets 'ai/' namespace, it should remain the same",
184+
},
185+
}
186+
187+
for _, tt := range tests {
188+
t.Run(tt.name, func(t *testing.T) {
189+
result := models.NormalizeModelName(tt.userProvidedName)
190+
if result != tt.expectedNormalizedName {
191+
t.Errorf("%s: NormalizeModelName(%q) = %q, want %q",
192+
tt.description, tt.userProvidedName, result, tt.expectedNormalizedName)
193+
}
194+
})
195+
}
196+
}

cmd/cli/desktop/desktop.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,8 +799,17 @@ func (c *Client) handleQueryError(err error, path string) error {
799799
return fmt.Errorf("error querying %s: %w", path, err)
800800
}
801801

802+
// normalizeHuggingFaceModelName converts Hugging Face model names to lowercase
803+
func normalizeHuggingFaceModelName(model string) string {
804+
if strings.HasPrefix(model, "hf.co/") {
805+
return strings.ToLower(model)
806+
}
807+
808+
return model
809+
}
810+
802811
func (c *Client) Tag(source, targetRepo, targetTag string) error {
803-
source = dmrm.NormalizeModelName(source)
812+
source = normalizeHuggingFaceModelName(source)
804813
// Check if the source is a model ID, and expand it if necessary
805814
if !strings.Contains(strings.Trim(source, "/"), "/") {
806815
// Do an extra API call to check if the model parameter might be a model ID

cmd/cli/desktop/desktop_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func TestTagHuggingFaceModel(t *testing.T) {
179179

180180
// Test case for tagging a Hugging Face model with mixed case
181181
sourceModel := "hf.co/Bartowski/Llama-3.2-1B-Instruct-GGUF"
182-
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf:latest"
182+
expectedLowercase := "hf.co/bartowski/llama-3.2-1b-instruct-gguf"
183183
targetRepo := "myrepo"
184184
targetTag := "latest"
185185

0 commit comments

Comments
 (0)