@@ -6,27 +6,20 @@ import (
66 "io"
77 "net/http"
88 "os"
9- "strings"
109
11- "github.com/google/go-containerregistry/pkg/authn"
12- "github.com/google/go-containerregistry/pkg/name"
13- v1 "github.com/google/go-containerregistry/pkg/v1"
14- "github.com/google/go-containerregistry/pkg/v1/remote"
1510 "github.com/sirupsen/logrus"
1611
12+ "github.com/docker/model-distribution/internal/progress"
1713 "github.com/docker/model-distribution/internal/store"
14+ "github.com/docker/model-distribution/registry"
1815 "github.com/docker/model-distribution/types"
1916)
2017
21- const (
22- defaultUserAgent = "model-distribution"
23- )
24-
2518// Client provides model distribution functionality
2619type Client struct {
27- store * store.LocalStore
28- log * logrus.Entry
29- remoteOptions []remote. Option
20+ store * store.LocalStore
21+ log * logrus.Entry
22+ registry * registry. Client
3023}
3124
3225// GetStorePath returns the root path where models are stored
@@ -84,8 +77,8 @@ func WithUserAgent(ua string) Option {
8477func defaultOptions () * options {
8578 return & options {
8679 logger : logrus .NewEntry (logrus .StandardLogger ()),
87- transport : remote .DefaultTransport ,
88- userAgent : defaultUserAgent ,
80+ transport : registry .DefaultTransport ,
81+ userAgent : registry . DefaultUserAgent ,
8982 }
9083}
9184
@@ -111,50 +104,29 @@ func NewClient(opts ...Option) (*Client, error) {
111104 return & Client {
112105 store : s ,
113106 log : options .logger ,
114- remoteOptions : []remote.Option {
115- remote .WithAuthFromKeychain (authn .DefaultKeychain ),
116- remote .WithTransport (options .transport ),
117- remote .WithUserAgent (options .userAgent ),
118- },
107+ registry : registry .NewClient (
108+ registry .WithTransport (options .transport ),
109+ registry .WithUserAgent (options .userAgent ),
110+ ),
119111 }, nil
120112}
121113
122114// PullModel pulls a model from a registry and returns the local file path
123115func (c * Client ) PullModel (ctx context.Context , reference string , progressWriter io.Writer ) error {
124116 c .log .Infoln ("Starting model pull:" , reference )
125117
126- // Parse the reference
127- ref , err := name .ParseReference (reference )
128- if err != nil {
129- return NewReferenceError (reference , err )
130- }
131-
132- // First, check the remote registry for the model's digest
133- c .log .Infoln ("Checking remote registry for model:" , reference )
134- opts := append ([]remote.Option {remote .WithContext (ctx )}, c .remoteOptions ... )
135- remoteImg , err := remote .Image (ref , opts ... )
118+ remoteModel , err := c .registry .Model (ctx , reference )
136119 if err != nil {
137- errStr := err .Error ()
138- if strings .Contains (errStr , "UNAUTHORIZED" ) {
139- return NewPullError (reference , "UNAUTHORIZED" , "Authentication required for this model" , err )
140- }
141- if strings .Contains (errStr , "MANIFEST_UNKNOWN" ) {
142- return NewPullError (reference , "MANIFEST_UNKNOWN" , "Model not found" , err )
143- }
144- if strings .Contains (errStr , "NAME_UNKNOWN" ) {
145- return NewPullError (reference , "NAME_UNKNOWN" , "Repository not found" , err )
146- }
147- c .log .Errorln ("Failed to check remote image:" , err , "reference:" , reference )
148- return NewPullError (reference , "UNKNOWN" , err .Error (), err )
120+ return fmt .Errorf ("reading model from registry: %w" , err )
149121 }
150122
151123 //Check for supported type
152- if err := checkCompat (remoteImg ); err != nil {
124+ if err := checkCompat (remoteModel ); err != nil {
153125 return err
154126 }
155127
156128 // Get the remote image digest
157- remoteDigest , err := remoteImg .Digest ()
129+ remoteDigest , err := remoteModel .Digest ()
158130 if err != nil {
159131 c .log .Errorln ("Failed to get remote image digest:" , err )
160132 return fmt .Errorf ("getting remote image digest: %w" , err )
@@ -178,7 +150,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
178150
179151 // Report progress for local model
180152 size := fileInfo .Size ()
181- err = writeSuccess (progressWriter , fmt .Sprintf ("Using cached model: %.2f MB" , float64 (size )/ 1024 / 1024 ))
153+ err = progress . WriteSuccess (progressWriter , fmt .Sprintf ("Using cached model: %.2f MB" , float64 (size )/ 1024 / 1024 ))
182154 if err != nil {
183155 c .log .Warnf ("Writing progress: %v" , err )
184156 // If we fail to write progress, don't try again
@@ -196,23 +168,23 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
196168
197169 // Model doesn't exist in local store or digests don't match, pull from remote
198170
199- pr := newProgressReporter (progressWriter , pullMsg )
171+ pr := progress . NewProgressReporter (progressWriter , progress . PullMsg )
200172 defer func () {
201173 if err := pr .Wait (); err != nil {
202174 c .log .Warnf ("Failed to write progress: %v" , err )
203175 }
204176 }()
205177
206- if err = c .store .Write (remoteImg , []string {reference }, pr .updates ()); err != nil {
207- if writeErr := writeError (progressWriter , fmt .Sprintf ("Error: %s" , err .Error ())); writeErr != nil {
178+ if err = c .store .Write (remoteModel , []string {reference }, pr .Updates ()); err != nil {
179+ if writeErr := progress . WriteError (progressWriter , fmt .Sprintf ("Error: %s" , err .Error ())); writeErr != nil {
208180 c .log .Warnf ("Failed to write error message: %v" , writeErr )
209181 // If we fail to write error message, don't try again
210182 progressWriter = nil
211183 }
212184 return fmt .Errorf ("writing image to store: %w" , err )
213185 }
214186
215- if err := writeSuccess (progressWriter , "Model pulled successfully" ); err != nil {
187+ if err := progress . WriteSuccess (progressWriter , "Model pulled successfully" ); err != nil {
216188 c .log .Warnf ("Failed to write success message: %v" , err )
217189 // If we fail to write success message, don't try again
218190 progressWriter = nil
@@ -307,9 +279,9 @@ func (c *Client) Tag(source string, target string) error {
307279// PushModel pushes a tagged model from the content store to the registry.
308280func (c * Client ) PushModel (ctx context.Context , tag string , progressWriter io.Writer ) (err error ) {
309281 // Parse the tag
310- ref , err := name . NewTag (tag )
282+ target , err := c . registry . NewTarget (tag )
311283 if err != nil {
312- return fmt .Errorf ("invalid tag %q : %w" , tag , err )
284+ return fmt .Errorf ("new tag: %w" , err )
313285 }
314286
315287 // Get the model from the store
@@ -320,36 +292,23 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr
320292
321293 // Push the model
322294 c .log .Infoln ("Pushing model:" , tag )
323-
324- pr := newProgressReporter (progressWriter , pushMsg )
325- defer func () {
326- if err := pr .Wait (); err != nil {
327- c .log .Warnf ("Failed to write progress: %v" , err )
328- }
329- }()
330-
331- opts := append ([]remote.Option {
332- remote .WithContext (ctx ),
333- remote .WithProgress (pr .updates ()),
334- }, c .remoteOptions ... )
335-
336- if err := remote .Write (ref , mdl , opts ... ); err != nil {
295+ if err := target .Write (ctx , mdl , progressWriter ); err != nil {
337296 c .log .Errorln ("Failed to push image:" , err , "reference:" , tag )
338- if writeErr := writeError (progressWriter , fmt .Sprintf ("Error: %s" , err .Error ())); writeErr != nil {
297+ if writeErr := progress . WriteError (progressWriter , fmt .Sprintf ("Error: %s" , err .Error ())); writeErr != nil {
339298 c .log .Warnf ("Failed to write error message: %v" , writeErr )
340299 }
341300 return fmt .Errorf ("pushing image: %w" , err )
342301 }
343302
344303 c .log .Infoln ("Successfully pushed model:" , tag )
345- if err := writeSuccess (progressWriter , "Model pushed successfully" ); err != nil {
304+ if err := progress . WriteSuccess (progressWriter , "Model pushed successfully" ); err != nil {
346305 c .log .Warnf ("Failed to write success message: %v" , err )
347306 }
348307
349308 return nil
350309}
351310
352- func checkCompat (image v1. Image ) error {
311+ func checkCompat (image types. ModelArtifact ) error {
353312 manifest , err := image .Manifest ()
354313 if err != nil {
355314 return err
0 commit comments