Skip to content

Commit ec2c025

Browse files
ekcaseydoringeman
authored andcommitted
Exposes Packaging SDK (docker#74)
* Provides API for packaging models - introduces builder package for building artifacts - extracts registry packge for direct interaction with OCI registries Signed-off-by: Emily Casey <[email protected]> * cleanup Signed-off-by: Emily Casey <[email protected]> * slim down artifact interface Signed-off-by: Emily Casey <[email protected]> * cleanup Signed-off-by: Emily Casey <[email protected]> --------- Signed-off-by: Emily Casey <[email protected]>
1 parent bec13c3 commit ec2c025

File tree

12 files changed

+348
-176
lines changed

12 files changed

+348
-176
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package builder
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
8+
"github.com/docker/model-distribution/internal/gguf"
9+
"github.com/docker/model-distribution/internal/mutate"
10+
"github.com/docker/model-distribution/internal/partial"
11+
"github.com/docker/model-distribution/types"
12+
)
13+
14+
// Builder builds a model artifact
15+
type Builder struct {
16+
model types.ModelArtifact
17+
}
18+
19+
// FromGGUF returns a *Builder that builds a model artifacts from a GGUF file
20+
func FromGGUF(path string) (*Builder, error) {
21+
mdl, err := gguf.NewModel(path)
22+
if err != nil {
23+
return nil, err
24+
}
25+
return &Builder{
26+
model: mdl,
27+
}, nil
28+
}
29+
30+
// WithLicense adds a license file to the artifact
31+
func (b *Builder) WithLicense(path string) (*Builder, error) {
32+
licenseLayer, err := partial.NewLayer(path, types.MediaTypeLicense)
33+
if err != nil {
34+
return nil, fmt.Errorf("license layer from %q: %w", path, err)
35+
}
36+
return &Builder{
37+
model: mutate.AppendLayers(b.model, licenseLayer),
38+
}, nil
39+
}
40+
41+
// Target represents a build target
42+
type Target interface {
43+
Write(context.Context, types.ModelArtifact, io.Writer) error
44+
}
45+
46+
// Build finalizes the artifact and writes it to the given target, reporting progress to the given writer
47+
func (b *Builder) Build(ctx context.Context, target Target, pw io.Writer) error {
48+
return target.Write(ctx, b.model, pw)
49+
}

pkg/distribution/distribution/client.go

Lines changed: 26 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,20 @@ import (
66
"io"
77
"net/http"
88
"os"
9-
"strings"
109

11-
"github.com/google/go-containerregistry/pkg/authn"
12-
"github.com/google/go-containerregistry/pkg/name"
13-
v1 "github.com/google/go-containerregistry/pkg/v1"
14-
"github.com/google/go-containerregistry/pkg/v1/remote"
1510
"github.com/sirupsen/logrus"
1611

12+
"github.com/docker/model-distribution/internal/progress"
1713
"github.com/docker/model-distribution/internal/store"
14+
"github.com/docker/model-distribution/registry"
1815
"github.com/docker/model-distribution/types"
1916
)
2017

21-
const (
22-
defaultUserAgent = "model-distribution"
23-
)
24-
2518
// Client provides model distribution functionality
2619
type Client struct {
27-
store *store.LocalStore
28-
log *logrus.Entry
29-
remoteOptions []remote.Option
20+
store *store.LocalStore
21+
log *logrus.Entry
22+
registry *registry.Client
3023
}
3124

3225
// GetStorePath returns the root path where models are stored
@@ -84,8 +77,8 @@ func WithUserAgent(ua string) Option {
8477
func defaultOptions() *options {
8578
return &options{
8679
logger: logrus.NewEntry(logrus.StandardLogger()),
87-
transport: remote.DefaultTransport,
88-
userAgent: defaultUserAgent,
80+
transport: registry.DefaultTransport,
81+
userAgent: registry.DefaultUserAgent,
8982
}
9083
}
9184

@@ -111,50 +104,29 @@ func NewClient(opts ...Option) (*Client, error) {
111104
return &Client{
112105
store: s,
113106
log: options.logger,
114-
remoteOptions: []remote.Option{
115-
remote.WithAuthFromKeychain(authn.DefaultKeychain),
116-
remote.WithTransport(options.transport),
117-
remote.WithUserAgent(options.userAgent),
118-
},
107+
registry: registry.NewClient(
108+
registry.WithTransport(options.transport),
109+
registry.WithUserAgent(options.userAgent),
110+
),
119111
}, nil
120112
}
121113

122114
// PullModel pulls a model from a registry and returns the local file path
123115
func (c *Client) PullModel(ctx context.Context, reference string, progressWriter io.Writer) error {
124116
c.log.Infoln("Starting model pull:", reference)
125117

126-
// Parse the reference
127-
ref, err := name.ParseReference(reference)
128-
if err != nil {
129-
return NewReferenceError(reference, err)
130-
}
131-
132-
// First, check the remote registry for the model's digest
133-
c.log.Infoln("Checking remote registry for model:", reference)
134-
opts := append([]remote.Option{remote.WithContext(ctx)}, c.remoteOptions...)
135-
remoteImg, err := remote.Image(ref, opts...)
118+
remoteModel, err := c.registry.Model(ctx, reference)
136119
if err != nil {
137-
errStr := err.Error()
138-
if strings.Contains(errStr, "UNAUTHORIZED") {
139-
return NewPullError(reference, "UNAUTHORIZED", "Authentication required for this model", err)
140-
}
141-
if strings.Contains(errStr, "MANIFEST_UNKNOWN") {
142-
return NewPullError(reference, "MANIFEST_UNKNOWN", "Model not found", err)
143-
}
144-
if strings.Contains(errStr, "NAME_UNKNOWN") {
145-
return NewPullError(reference, "NAME_UNKNOWN", "Repository not found", err)
146-
}
147-
c.log.Errorln("Failed to check remote image:", err, "reference:", reference)
148-
return NewPullError(reference, "UNKNOWN", err.Error(), err)
120+
return fmt.Errorf("reading model from registry: %w", err)
149121
}
150122

151123
//Check for supported type
152-
if err := checkCompat(remoteImg); err != nil {
124+
if err := checkCompat(remoteModel); err != nil {
153125
return err
154126
}
155127

156128
// Get the remote image digest
157-
remoteDigest, err := remoteImg.Digest()
129+
remoteDigest, err := remoteModel.Digest()
158130
if err != nil {
159131
c.log.Errorln("Failed to get remote image digest:", err)
160132
return fmt.Errorf("getting remote image digest: %w", err)
@@ -178,7 +150,7 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
178150

179151
// Report progress for local model
180152
size := fileInfo.Size()
181-
err = writeSuccess(progressWriter, fmt.Sprintf("Using cached model: %.2f MB", float64(size)/1024/1024))
153+
err = progress.WriteSuccess(progressWriter, fmt.Sprintf("Using cached model: %.2f MB", float64(size)/1024/1024))
182154
if err != nil {
183155
c.log.Warnf("Writing progress: %v", err)
184156
// If we fail to write progress, don't try again
@@ -196,23 +168,23 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
196168

197169
// Model doesn't exist in local store or digests don't match, pull from remote
198170

199-
pr := newProgressReporter(progressWriter, pullMsg)
171+
pr := progress.NewProgressReporter(progressWriter, progress.PullMsg)
200172
defer func() {
201173
if err := pr.Wait(); err != nil {
202174
c.log.Warnf("Failed to write progress: %v", err)
203175
}
204176
}()
205177

206-
if err = c.store.Write(remoteImg, []string{reference}, pr.updates()); err != nil {
207-
if writeErr := writeError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
178+
if err = c.store.Write(remoteModel, []string{reference}, pr.Updates()); err != nil {
179+
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
208180
c.log.Warnf("Failed to write error message: %v", writeErr)
209181
// If we fail to write error message, don't try again
210182
progressWriter = nil
211183
}
212184
return fmt.Errorf("writing image to store: %w", err)
213185
}
214186

215-
if err := writeSuccess(progressWriter, "Model pulled successfully"); err != nil {
187+
if err := progress.WriteSuccess(progressWriter, "Model pulled successfully"); err != nil {
216188
c.log.Warnf("Failed to write success message: %v", err)
217189
// If we fail to write success message, don't try again
218190
progressWriter = nil
@@ -307,9 +279,9 @@ func (c *Client) Tag(source string, target string) error {
307279
// PushModel pushes a tagged model from the content store to the registry.
308280
func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Writer) (err error) {
309281
// Parse the tag
310-
ref, err := name.NewTag(tag)
282+
target, err := c.registry.NewTarget(tag)
311283
if err != nil {
312-
return fmt.Errorf("invalid tag %q: %w", tag, err)
284+
return fmt.Errorf("new tag: %w", err)
313285
}
314286

315287
// Get the model from the store
@@ -320,36 +292,23 @@ func (c *Client) PushModel(ctx context.Context, tag string, progressWriter io.Wr
320292

321293
// Push the model
322294
c.log.Infoln("Pushing model:", tag)
323-
324-
pr := newProgressReporter(progressWriter, pushMsg)
325-
defer func() {
326-
if err := pr.Wait(); err != nil {
327-
c.log.Warnf("Failed to write progress: %v", err)
328-
}
329-
}()
330-
331-
opts := append([]remote.Option{
332-
remote.WithContext(ctx),
333-
remote.WithProgress(pr.updates()),
334-
}, c.remoteOptions...)
335-
336-
if err := remote.Write(ref, mdl, opts...); err != nil {
295+
if err := target.Write(ctx, mdl, progressWriter); err != nil {
337296
c.log.Errorln("Failed to push image:", err, "reference:", tag)
338-
if writeErr := writeError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
297+
if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
339298
c.log.Warnf("Failed to write error message: %v", writeErr)
340299
}
341300
return fmt.Errorf("pushing image: %w", err)
342301
}
343302

344303
c.log.Infoln("Successfully pushed model:", tag)
345-
if err := writeSuccess(progressWriter, "Model pushed successfully"); err != nil {
304+
if err := progress.WriteSuccess(progressWriter, "Model pushed successfully"); err != nil {
346305
c.log.Warnf("Failed to write success message: %v", err)
347306
}
348307

349308
return nil
350309
}
351310

352-
func checkCompat(image v1.Image) error {
311+
func checkCompat(image types.ModelArtifact) error {
353312
manifest, err := image.Manifest()
354313
if err != nil {
355314
return err

pkg/distribution/distribution/client_test.go

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ import (
1616
"strings"
1717
"testing"
1818

19+
"github.com/docker/model-distribution/internal/progress"
1920
"github.com/google/go-containerregistry/pkg/name"
2021
"github.com/google/go-containerregistry/pkg/registry"
2122
"github.com/google/go-containerregistry/pkg/v1/remote"
2223
"github.com/sirupsen/logrus"
2324

2425
"github.com/docker/model-distribution/internal/gguf"
2526
"github.com/docker/model-distribution/internal/mutate"
27+
mdregistry "github.com/docker/model-distribution/registry"
2628
)
2729

2830
var (
@@ -158,8 +160,8 @@ func TestClientPullModel(t *testing.T) {
158160
t.Fatal("Expected error for non-existent model, got nil")
159161
}
160162

161-
// Verify it's a PullError
162-
var pullErr *PullError
163+
// Verify it's a registry.Error
164+
var pullErr *mdregistry.Error
163165
ok := errors.As(err, &pullErr)
164166
if !ok {
165167
t.Fatalf("Expected PullError, got %T", err)
@@ -178,7 +180,7 @@ func TestClientPullModel(t *testing.T) {
178180
if pullErr.Err == nil {
179181
t.Error("Expected underlying error to be non-nil")
180182
}
181-
if !errors.Is(pullErr, ErrModelNotFound) {
183+
if !errors.Is(pullErr, mdregistry.ErrModelNotFound) {
182184
t.Errorf("Expected underlying error to match ErrModelNotFound, got %v", pullErr.Err)
183185
}
184186
})
@@ -422,11 +424,11 @@ func TestClientPullModel(t *testing.T) {
422424
}
423425

424426
// Parse progress output as JSON
425-
var messages []ProgressMessage
427+
var messages []progress.ProgressMessage
426428
scanner := bufio.NewScanner(&progressBuffer)
427429
for scanner.Scan() {
428430
line := scanner.Text()
429-
var msg ProgressMessage
431+
var msg progress.ProgressMessage
430432
if err := json.Unmarshal([]byte(line), &msg); err != nil {
431433
t.Fatalf("Failed to parse JSON progress message: %v, line: %s", err, line)
432434
}
@@ -496,10 +498,9 @@ func TestClientPullModel(t *testing.T) {
496498
t.Fatal("Expected error for non-existent model, got nil")
497499
}
498500

499-
// Verify it's a PullError
500-
var pullErr *PullError
501-
if !errors.As(err, &pullErr) {
502-
t.Fatalf("Expected PullError, got %T", err)
501+
// Verify it matches registry.ErrModelNotFound
502+
if !errors.Is(err, mdregistry.ErrModelNotFound) {
503+
t.Fatalf("Expected registry.ErrModelNotFound, got %T", err)
503504
}
504505

505506
// No JSON messages should be in the buffer for this error case
@@ -815,18 +816,8 @@ func TestNewReferenceError(t *testing.T) {
815816
t.Fatal("Expected error for invalid reference, got nil")
816817
}
817818

818-
// Verify it's a ReferenceError
819-
refErr, ok := err.(*ReferenceError)
820-
if !ok {
821-
t.Fatalf("Expected ReferenceError, got %T", err)
822-
}
823-
824-
// Verify error fields
825-
if refErr.Reference != invalidRef {
826-
t.Errorf("Expected reference %q, got %q", invalidRef, refErr.Reference)
827-
}
828-
if refErr.Err == nil {
829-
t.Error("Expected underlying error to be non-nil")
819+
if !errors.Is(err, ErrInvalidReference) {
820+
t.Fatalf("Expected error to match sentinel invalid reference error, got %v", err)
830821
}
831822
}
832823

pkg/distribution/distribution/errors.go

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import (
55
"fmt"
66

77
"github.com/docker/model-distribution/internal/store"
8+
"github.com/docker/model-distribution/registry"
89
"github.com/docker/model-distribution/types"
910
)
1011

1112
var (
12-
ErrInvalidReference = errors.New("invalid model reference")
13-
ErrModelNotFound = store.ErrModelNotFound
14-
ErrUnauthorized = errors.New("unauthorized access to model")
13+
ErrInvalidReference = registry.ErrInvalidReference
14+
ErrModelNotFound = store.ErrModelNotFound // model not found in store
1515
ErrUnsupportedMediaType = errors.New(fmt.Sprintf(
1616
"client supports only models of type %q and older - try upgrading",
1717
types.MediaTypeModelConfigV01,
@@ -37,51 +37,3 @@ func (e *ReferenceError) Unwrap() error {
3737
func (e *ReferenceError) Is(target error) bool {
3838
return target == ErrInvalidReference
3939
}
40-
41-
// PullError represents an error that occurs when pulling a model
42-
type PullError struct {
43-
Reference string
44-
// Code should be one of error codes defined in the distribution spec
45-
// (see https://github.com/opencontainers/distribution-spec/blob/583e014d15418d839d67f68152bc2c83821770e0/spec.md#error-codes)
46-
Code string
47-
Message string
48-
Err error
49-
}
50-
51-
func (e *PullError) Error() string {
52-
return fmt.Sprintf("failed to pull model %q: %s - %s", e.Reference, e.Code, e.Message)
53-
}
54-
55-
func (e *PullError) Unwrap() error {
56-
return e.Err
57-
}
58-
59-
// Is implements error matching for PullError
60-
func (e *PullError) Is(target error) bool {
61-
switch target {
62-
case ErrModelNotFound:
63-
return e.Code == "MANIFEST_UNKNOWN" || e.Code == "NAME_UNKNOWN"
64-
case ErrUnauthorized:
65-
return e.Code == "UNAUTHORIZED"
66-
default:
67-
return false
68-
}
69-
}
70-
71-
// NewReferenceError creates a new ReferenceError
72-
func NewReferenceError(reference string, err error) error {
73-
return &ReferenceError{
74-
Reference: reference,
75-
Err: err,
76-
}
77-
}
78-
79-
// NewPullError creates a new PullError
80-
func NewPullError(reference, code, message string, err error) error {
81-
return &PullError{
82-
Reference: reference,
83-
Code: code,
84-
Message: message,
85-
Err: err,
86-
}
87-
}

0 commit comments

Comments
 (0)