Skip to content

Commit e7a45b9

Browse files
committed
Fix model name normalization
Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent 22ec140 commit e7a45b9

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

cmd/cli/commands/run.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/docker/model-runner/cmd/cli/commands/completion"
1616
"github.com/docker/model-runner/cmd/cli/desktop"
1717
"github.com/docker/model-runner/cmd/cli/readline"
18+
dmrm "github.com/docker/model-runner/pkg/inference/models"
1819
"github.com/muesli/termenv"
1920

2021
"github.com/fatih/color"
@@ -676,6 +677,8 @@ func newRunCmd() *cobra.Command {
676677
return nil
677678
}
678679

680+
model = dmrm.NormalizeModelName(model)
681+
679682
if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), asPrinter(cmd), debug); err != nil {
680683
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
681684
}

cmd/cli/commands/run_test.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"strings"
66
"testing"
77

8+
dmrm "github.com/docker/model-runner/pkg/inference/models"
89
"github.com/spf13/cobra"
910
)
1011

@@ -156,3 +157,147 @@ func TestRunCmdDetachFlag(t *testing.T) {
156157
t.Errorf("Expected detach flag value to be true, got false")
157158
}
158159
}
160+
161+
// TestRunModelNameNormalization verifies that model names are normalized correctly
162+
// in the run command to ensure consistency with how models are stored after pulling
163+
func TestRunModelNameNormalization(t *testing.T) {
164+
tests := []struct {
165+
name string
166+
userProvidedModelName string
167+
expectedNormalizedName string
168+
description string
169+
}{
170+
{
171+
name: "simple model name without namespace",
172+
userProvidedModelName: "llama3",
173+
expectedNormalizedName: "ai/llama3:latest",
174+
description: "When user runs 'docker model run llama3', it should be normalized to 'ai/llama3:latest'",
175+
},
176+
{
177+
name: "model name with tag but no namespace",
178+
userProvidedModelName: "llama3:8b",
179+
expectedNormalizedName: "ai/llama3:8b",
180+
description: "When user runs 'docker model run llama3:8b', it should be normalized to 'ai/llama3:8b'",
181+
},
182+
{
183+
name: "model name with explicit namespace",
184+
userProvidedModelName: "myorg/llama3",
185+
expectedNormalizedName: "myorg/llama3:latest",
186+
description: "When user runs 'docker model run myorg/llama3', it should preserve the namespace",
187+
},
188+
{
189+
name: "model name with ai namespace already set",
190+
userProvidedModelName: "ai/llama3",
191+
expectedNormalizedName: "ai/llama3:latest",
192+
description: "When user runs 'docker model run ai/llama3', it should remain as 'ai/llama3:latest'",
193+
},
194+
{
195+
name: "fully qualified model name",
196+
userProvidedModelName: "ai/llama3:latest",
197+
expectedNormalizedName: "ai/llama3:latest",
198+
description: "When user runs 'docker model run ai/llama3:latest', it should remain unchanged",
199+
},
200+
{
201+
name: "model name with custom org and tag",
202+
userProvidedModelName: "myorg/llama3:v2",
203+
expectedNormalizedName: "myorg/llama3:v2",
204+
description: "When user runs 'docker model run myorg/llama3:v2', it should remain unchanged",
205+
},
206+
{
207+
name: "huggingface model",
208+
userProvidedModelName: "hf.co/meta-llama/Llama-3-8B",
209+
expectedNormalizedName: "hf.co/meta-llama/llama-3-8b:latest",
210+
description: "HuggingFace models should be lowercased and tagged with :latest",
211+
},
212+
{
213+
name: "registry prefixed model",
214+
userProvidedModelName: "registry.example.com/mymodel",
215+
expectedNormalizedName: "registry.example.com/mymodel:latest",
216+
description: "Registry-prefixed models should only get :latest tag added",
217+
},
218+
}
219+
220+
for _, tt := range tests {
221+
t.Run(tt.name, func(t *testing.T) {
222+
// Test that the normalization function produces the expected output
223+
normalizedName := dmrm.NormalizeModelName(tt.userProvidedModelName)
224+
225+
if normalizedName != tt.expectedNormalizedName {
226+
t.Errorf("NormalizeModelName(%q) = %q, want %q\nDescription: %s",
227+
tt.userProvidedModelName,
228+
normalizedName,
229+
tt.expectedNormalizedName,
230+
tt.description)
231+
}
232+
})
233+
}
234+
}
235+
236+
// TestRunModelNameNormalizationConsistency verifies that the run command
237+
// uses the same normalization as the pull command, ensuring that:
238+
// 1. A model pulled as "docker model pull mymodel" creates "ai/mymodel:latest"
239+
// 2. The same model can be run as "docker model run mymodel" (without ai/ prefix)
240+
func TestRunModelNameNormalizationConsistency(t *testing.T) {
241+
testCases := []struct {
242+
name string
243+
userInputForPull string
244+
userInputForRun string
245+
expectedInternalReference string
246+
description string
247+
}{
248+
{
249+
name: "pull and run without namespace",
250+
userInputForPull: "gemma3",
251+
userInputForRun: "gemma3",
252+
expectedInternalReference: "ai/gemma3:latest",
253+
description: "User pulls 'gemma3' and runs 'gemma3' - both should resolve to 'ai/gemma3:latest'",
254+
},
255+
{
256+
name: "pull with namespace and run without namespace",
257+
userInputForPull: "ai/gemma3",
258+
userInputForRun: "gemma3",
259+
expectedInternalReference: "ai/gemma3:latest",
260+
description: "User pulls 'ai/gemma3' and runs 'gemma3' - both should resolve to 'ai/gemma3:latest'",
261+
},
262+
{
263+
name: "pull without namespace and run with namespace",
264+
userInputForPull: "gemma3",
265+
userInputForRun: "ai/gemma3",
266+
expectedInternalReference: "ai/gemma3:latest",
267+
description: "User pulls 'gemma3' and runs 'ai/gemma3' - both should resolve to 'ai/gemma3:latest'",
268+
},
269+
{
270+
name: "pull and run with tag",
271+
userInputForPull: "gemma3:2b",
272+
userInputForRun: "gemma3:2b",
273+
expectedInternalReference: "ai/gemma3:2b",
274+
description: "User pulls 'gemma3:2b' and runs 'gemma3:2b' - both should resolve to 'ai/gemma3:2b'",
275+
},
276+
{
277+
name: "custom org is preserved",
278+
userInputForPull: "myorg/gemma3",
279+
userInputForRun: "myorg/gemma3",
280+
expectedInternalReference: "myorg/gemma3:latest",
281+
description: "User pulls 'myorg/gemma3' and runs 'myorg/gemma3' - both should resolve to 'myorg/gemma3:latest'",
282+
},
283+
}
284+
285+
for _, tc := range testCases {
286+
t.Run(tc.name, func(t *testing.T) {
287+
// Simulate what happens during pull (model name is normalized before storage)
288+
normalizedPullName := dmrm.NormalizeModelName(tc.userInputForPull)
289+
290+
// Simulate what should happen during run (model name is normalized before lookup)
291+
normalizedRunName := dmrm.NormalizeModelName(tc.userInputForRun)
292+
293+
// Both should normalize to the same internal reference
294+
if normalizedPullName != tc.expectedInternalReference || normalizedRunName != tc.expectedInternalReference {
295+
t.Errorf("Normalization failed for test case %q:\n Pull input: %q -> Got %q, Want %q\n Run input: %q -> Got %q, Want %q\n Description: %s",
296+
tc.name,
297+
tc.userInputForPull, normalizedPullName, tc.expectedInternalReference,
298+
tc.userInputForRun, normalizedRunName, tc.expectedInternalReference,
299+
tc.description)
300+
}
301+
})
302+
}
303+
}

0 commit comments

Comments
 (0)