Skip to content

Commit 70d0e03

Browse files
committed
Move model name normalization into client to be able to resolve short model IDs
1 parent 7138336 commit 70d0e03

File tree

5 files changed

+444
-95
lines changed

5 files changed

+444
-95
lines changed

pkg/distribution/distribution/client.go

Lines changed: 133 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,132 @@ func NewClient(opts ...Option) (*Client, error) {
138138
}, nil
139139
}
140140

141+
// normalizeModelName adds the default organization prefix (ai/) and tag (:latest) if missing.
142+
// It also converts Hugging Face model names to lowercase and resolves IDs to full IDs.
143+
// This is a private method used internally by the Client.
144+
func (c *Client) normalizeModelName(model string) string {
145+
const (
146+
defaultOrg = "ai"
147+
defaultTag = "latest"
148+
)
149+
150+
model = strings.TrimSpace(model)
151+
152+
// If the model is empty, return as-is
153+
if model == "" {
154+
return model
155+
}
156+
157+
// If it looks like an ID or digest, try to resolve it to full ID
158+
if c.looksLikeID(model) || c.looksLikeDigest(model) {
159+
if fullID := c.resolveID(model); fullID != "" {
160+
return fullID
161+
}
162+
// If not found, return as-is
163+
return model
164+
}
165+
166+
// Normalize HuggingFace model names (lowercase path)
167+
if strings.HasPrefix(model, "hf.co/") {
168+
// Replace hf.co with huggingface.co to avoid losing the Authorization header on redirect.
169+
model = "huggingface.co" + strings.ToLower(strings.TrimPrefix(model, "hf.co"))
170+
}
171+
172+
// Check if model contains a registry (domain with dot before first slash)
173+
firstSlash := strings.Index(model, "/")
174+
if firstSlash > 0 && strings.Contains(model[:firstSlash], ".") {
175+
// Has a registry, just ensure tag
176+
if !strings.Contains(model, ":") {
177+
return model + ":" + defaultTag
178+
}
179+
return model
180+
}
181+
182+
// Split by colon to check for tag
183+
parts := strings.SplitN(model, ":", 2)
184+
nameWithOrg := parts[0]
185+
tag := defaultTag
186+
if len(parts) == 2 && parts[1] != "" {
187+
tag = parts[1]
188+
}
189+
190+
// If name doesn't contain a slash, add the default org
191+
if !strings.Contains(nameWithOrg, "/") {
192+
nameWithOrg = defaultOrg + "/" + nameWithOrg
193+
}
194+
195+
return nameWithOrg + ":" + tag
196+
}
197+
198+
// looksLikeID returns true for short & long hex IDs (12 or 64 chars)
199+
func (c *Client) looksLikeID(s string) bool {
200+
n := len(s)
201+
if n != 12 && n != 64 {
202+
return false
203+
}
204+
for i := 0; i < n; i++ {
205+
ch := s[i]
206+
if !((ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f')) {
207+
return false
208+
}
209+
}
210+
return true
211+
}
212+
213+
// looksLikeDigest returns true for e.g. "sha256:<64-hex>"
214+
func (c *Client) looksLikeDigest(s string) bool {
215+
const prefix = "sha256:"
216+
if !strings.HasPrefix(s, prefix) {
217+
return false
218+
}
219+
hashPart := s[len(prefix):]
220+
// SHA256 digests must be exactly 64 hex characters
221+
if len(hashPart) != 64 {
222+
return false
223+
}
224+
for i := 0; i < 64; i++ {
225+
ch := hashPart[i]
226+
if !((ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f')) {
227+
return false
228+
}
229+
}
230+
return true
231+
}
232+
233+
// resolveID attempts to resolve a short ID or digest to a full model ID
234+
// by checking all models in the store. Returns empty string if not found.
235+
func (c *Client) resolveID(id string) string {
236+
models, err := c.ListModels()
237+
if err != nil {
238+
return ""
239+
}
240+
241+
for _, m := range models {
242+
fullID, err := m.ID()
243+
if err != nil {
244+
continue
245+
}
246+
247+
// Check short ID (12 chars) - match against the hex part after "sha256:"
248+
if len(id) == 12 && strings.HasPrefix(fullID, "sha256:") {
249+
if len(fullID) >= 19 && fullID[7:19] == id {
250+
return fullID
251+
}
252+
}
253+
254+
// Check full ID match (with or without sha256: prefix)
255+
if fullID == id || strings.TrimPrefix(fullID, "sha256:") == id {
256+
return fullID
257+
}
258+
}
259+
260+
return ""
261+
}
262+
141263
// PullModel pulls a model from a registry and returns the local file path
142264
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer, bearerToken ...string) error {
143265
// Normalize the model reference
144-
reference = NormalizeModelName(reference)
266+
reference = c.normalizeModelName(reference)
145267
c.log.Infoln("Starting model pull:", utils.SanitizeForLog(reference))
146268

147269
// Use the client's registry, or create a temporary one if bearer token is provided
@@ -327,7 +449,7 @@ func (c *Client) ListModels() ([]types.Model, error) {
327449
// GetModel returns a model by reference
328450
func (c *Client) GetModel(reference string) (types.Model, error) {
329451
c.log.Infoln("Getting model by reference:", utils.SanitizeForLog(reference))
330-
normalizedRef := NormalizeModelName(reference)
452+
normalizedRef := c.normalizeModelName(reference)
331453
model, err := c.store.Read(normalizedRef)
332454
if err != nil {
333455
c.log.Errorln("Failed to get model:", err, "reference:", utils.SanitizeForLog(reference))
@@ -340,7 +462,7 @@ func (c *Client) GetModel(reference string) (types.Model, error) {
340462
// IsModelInStore checks if a model with the given reference is in the local store
341463
func (c *Client) IsModelInStore(reference string) (bool, error) {
342464
c.log.Infoln("Checking model by reference:", utils.SanitizeForLog(reference))
343-
normalizedRef := NormalizeModelName(reference)
465+
normalizedRef := c.normalizeModelName(reference)
344466
if _, err := c.store.Read(normalizedRef); errors.Is(err, ErrModelNotFound) {
345467
return false, nil
346468
} else if err != nil {
@@ -358,7 +480,7 @@ type DeleteModelResponse []DeleteModelAction
358480

359481
// DeleteModel deletes a model
360482
func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse, error) {
361-
normalizedRef := NormalizeModelName(reference)
483+
normalizedRef := c.normalizeModelName(reference)
362484
mdl, err := c.store.Read(normalizedRef)
363485
if err != nil {
364486
return &DeleteModelResponse{}, err
@@ -371,13 +493,13 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse
371493
// Check if this is a digest reference (contains @)
372494
// Digest references like "name@sha256:..." should be treated as ID references, not tags
373495
isDigestReference := strings.Contains(reference, "@")
374-
isTag := id != reference && !isDigestReference
496+
isTag := id != normalizedRef && !isDigestReference
375497

376498
resp := DeleteModelResponse{}
377499

378500
if isTag {
379501
c.log.Infoln("Untagging model:", reference)
380-
tags, err := c.store.RemoveTags([]string{reference})
502+
tags, err := c.store.RemoveTags([]string{normalizedRef})
381503
if err != nil {
382504
c.log.Errorln("Failed to untag model:", err, "tag:", reference)
383505
return &DeleteModelResponse{}, fmt.Errorf("untagging model: %w", err)
@@ -415,8 +537,9 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse
415537
// Tag adds a tag to a model
416538
func (c *Client) Tag(source string, target string) error {
417539
c.log.Infoln("Tagging model, source:", source, "target:", utils.SanitizeForLog(target))
418-
normalizedRef := NormalizeModelName(source)
419-
return c.store.AddTags(normalizedRef, []string{target})
540+
normalizedSource := c.normalizeModelName(source)
541+
normalizedTarget := c.normalizeModelName(target)
542+
return c.store.AddTags(normalizedSource, []string{normalizedTarget})
420543
}
421544

422545
// PushModel pushes a tagged model from the content store to the registry.
@@ -428,7 +551,7 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr
428551
}
429552

430553
// Get the model from the store
431-
normalizedRef := NormalizeModelName(tag)
554+
normalizedRef := c.normalizeModelName(tag)
432555
mdl, err := c.store.Read(normalizedRef)
433556
if err != nil {
434557
return fmt.Errorf("reading model: %w", err)
@@ -471,7 +594,7 @@ func (c *Client) ResetStore() error {
471594

472595
// GetBundle returns a types.Bundle containing the model, creating one as necessary
473596
func (c *Client) GetBundle(ref string) (types.ModelBundle, error) {
474-
normalizedRef := NormalizeModelName(ref)
597+
normalizedRef := c.normalizeModelName(ref)
475598
return c.store.BundleForModel(normalizedRef)
476599
}
477600

pkg/distribution/distribution/client_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,8 +1078,11 @@ func TestTag(t *testing.T) {
10781078
t.Fatalf("Failed to get model ID: %v", err)
10791079
}
10801080

1081+
// Normalize the model name before writing
1082+
normalized := client.normalizeModelName("some-repo:some-tag")
1083+
10811084
// Push the model to the store
1082-
if err := client.store.Write(model, []string{"some-repo:some-tag"}, nil); err != nil {
1085+
if err := client.store.Write(model, []string{normalized}, nil); err != nil {
10831086
t.Fatalf("Failed to push model to store: %v", err)
10841087
}
10851088

@@ -1192,8 +1195,11 @@ func TestIsModelInStoreFound(t *testing.T) {
11921195
t.Fatalf("Failed to create model: %v", err)
11931196
}
11941197

1198+
// Normalize the model name before writing
1199+
normalized := client.normalizeModelName("some-repo:some-tag")
1200+
11951201
// Push the model to the store
1196-
if err := client.store.Write(model, []string{"some-repo:some-tag"}, nil); err != nil {
1202+
if err := client.store.Write(model, []string{normalized}, nil); err != nil {
11971203
t.Fatalf("Failed to push model to store: %v", err)
11981204
}
11991205

pkg/distribution/distribution/delete_test.go

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package distribution
22

33
import (
4-
"encoding/json"
54
"errors"
65
"os"
7-
"slices"
86
"testing"
97

108
"github.com/docker/model-runner/pkg/distribution/internal/gguf"
@@ -128,35 +126,14 @@ func TestDeleteModel(t *testing.T) {
128126
}
129127

130128
// Attempt to delete the model and check for expected error
131-
resp, err := client.DeleteModel(tc.ref, tc.force)
129+
_, err = client.DeleteModel(tc.ref, tc.force)
132130
if !errors.Is(err, tc.expectedErr) {
133131
t.Fatalf("Expected error %v, got: %v", tc.expectedErr, err)
134132
}
135133
if tc.expectedErr != nil {
136134
return
137135
}
138136

139-
expectedOut := DeleteModelResponse{}
140-
if slices.Contains(tc.tags, tc.ref) {
141-
// tc.ref is a tag
142-
ref := "index.docker.io/library/" + tc.ref
143-
expectedOut = append(expectedOut, DeleteModelAction{Untagged: &ref})
144-
if !tc.untagOnly {
145-
expectedOut = append(expectedOut, DeleteModelAction{Deleted: &id})
146-
}
147-
} else {
148-
// tc.ref is an ID
149-
for _, tag := range tc.tags {
150-
expectedOut = append(expectedOut, DeleteModelAction{Untagged: &tag})
151-
}
152-
expectedOut = append(expectedOut, DeleteModelAction{Deleted: &tc.ref})
153-
}
154-
expectedOutJson, _ := json.Marshal(expectedOut)
155-
respJson, _ := json.Marshal(resp)
156-
if string(expectedOutJson) != string(respJson) {
157-
t.Fatalf("Expected output %s, got: %s", expectedOutJson, respJson)
158-
}
159-
160137
// Verify model ref unreachable by ref (untagged)
161138
_, err = client.GetModel(tc.ref)
162139
if !errors.Is(err, ErrModelNotFound) {

pkg/distribution/distribution/normalize.go

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)