pkg/distribution/distribution/client.go (+12 -7)
@@ -22,9 +22,10 @@ import (
 
 // Client provides model distribution functionality
 type Client struct {
-	store    *store.LocalStore
-	log      *logrus.Entry
-	registry *registry.Client
+	store     *store.LocalStore
+	log       *logrus.Entry
+	registry  *registry.Client
+	transport http.RoundTripper
 }
 
 // GetStorePath returns the root path where models are stored
@@ -130,9 +131,10 @@ func NewClient(opts ...Option) (*Client, error) {
 
 	options.logger.Infoln("Successfully initialized store")
 	return &Client{
-		store:    s,
-		log:      options.logger,
-		registry: registry.NewClient(registryOpts...),
+		store:     s,
+		log:       options.logger,
+		registry:  registry.NewClient(registryOpts...),
+		transport: options.transport,
 	}, nil
 }
 
@@ -185,7 +187,10 @@ func (c *Client) PullModel(ctx context.Context, reference string, progressWriter
 
 	// Model doesn't exist in local store or digests don't match, pull from remote
 
-	if err = c.store.Write(remoteModel, []string{reference}, progressWriter); err != nil {
+	// Wrap the remote model with resumable download support
+	resumableModel := c.wrapWithResumableImage(ctx, remoteModel, reference, c.transport)
+
+	if err = c.store.Write(resumableModel, []string{reference}, progressWriter); err != nil {
 		if writeErr := progress.WriteError(progressWriter, fmt.Sprintf("Error: %s", err.Error())); writeErr != nil {
 			c.log.Warnf("Failed to write error message: %v", writeErr)
 			// If we fail to write error message, don't try again
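The transport stored on the Client arrives through the existing functional options (options.transport above). A minimal usage sketch, assuming a hypothetical WithTransport option constructor, since the diff shows the field being threaded through but not the option that sets it:

	// WithTransport is an assumed option name, shown for illustration only;
	// the diff confirms options.transport exists but not how it is set.
	func newClientWithTransport(rt http.RoundTripper) (*distribution.Client, error) {
		return distribution.NewClient(distribution.WithTransport(rt))
	}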
pkg/distribution/distribution/resumable_image.go (new file, +72)
@@ -0,0 +1,72 @@
package distribution

import (
	"context"
	"net/http"

	"github.com/docker/model-runner/pkg/distribution/internal/store"
	"github.com/docker/model-runner/pkg/distribution/registry"
	v1 "github.com/google/go-containerregistry/pkg/v1"
)

// resumableImage wraps a v1.Image and returns ResumableLayers
type resumableImage struct {
	v1.Image
	store      *store.LocalStore
	httpClient *http.Client
	reference  string
	registry   *registry.Client
	ctx        context.Context
}

// Layers returns wrapped layers with resumable download support
func (ri *resumableImage) Layers() ([]v1.Layer, error) {
	layers, err := ri.Image.Layers()
	if err != nil {
		return nil, err
	}

	// Wrap each layer with resumable support
	wrapped := make([]v1.Layer, len(layers))
	for i, layer := range layers {
		// Get digest for this layer
		digest, err := layer.Digest()
		if err != nil {
			// If we can't get digest, just use original layer
			wrapped[i] = layer
			continue
		}

		// Get blob URL for this layer
		blobURL, err := ri.registry.BlobURL(ri.reference, digest)
		if err != nil {
			// If we can't get URL, just use original layer
			wrapped[i] = layer
			continue
		}

		// Get auth token
		authToken, err := ri.registry.BearerToken(ri.ctx, ri.reference)
		if err != nil {
			// If we can't get auth token, try without it
			authToken = ""
		}

		// Create resumable layer
		wrapped[i] = store.NewResumableLayer(layer, ri.store, ri.httpClient, blobURL, authToken)
	}

	return wrapped, nil
}

// wrapWithResumableImage wraps an image to support resumable downloads
func (c *Client) wrapWithResumableImage(ctx context.Context, img v1.Image, reference string, transport http.RoundTripper) v1.Image {
	return &resumableImage{
		Image:      img,
		store:      c.store,
		httpClient: &http.Client{Transport: transport},
		reference:  reference,
		registry:   c.registry,
		ctx:        ctx,
	}
}
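The wrapper depends on only two operations of registry.Client. A minimal sketch of that surface, with signatures inferred from the call sites above; the actual declarations in pkg/distribution/registry may differ:

	// Sketch of the registry surface resumableImage relies on; signatures
	// are inferred from usage and may not match the real registry.Client.
	type layerSource interface {
		// BlobURL resolves the direct download URL for a layer blob.
		BlobURL(reference string, digest v1.Hash) (string, error)
		// BearerToken returns a token for the Authorization: Bearer header.
		BearerToken(ctx context.Context, reference string) (string, error)
	}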
pkg/distribution/internal/store/blobs.go (+19 -5)
@@ -81,6 +81,12 @@ type blob interface {
 // writeLayer writes the layer blob to the store.
 // It returns true when a new blob was created and the blob's DiffID.
 func (s *LocalStore) writeLayer(layer blob, updates chan<- v1.Update) (bool, v1.Hash, error) {
+	// Check if this is a ResumableLayer and use its special download method
+	if resumableLayer, ok := layer.(*ResumableLayer); ok {
+		return resumableLayer.DownloadAndDecompress(updates)
+	}
+
+	// Standard layer download (non-resumable)
 	hash, err := layer.DiffID()
 	if err != nil {
 		return false, v1.Hash{}, fmt.Errorf("get file hash: %w", err)
@@ -122,21 +128,29 @@ func (s *LocalStore) WriteBlob(diffID v1.Hash, r io.Reader) error {
 	if err != nil {
 		return fmt.Errorf("get blob path: %w", err)
 	}
-	f, err := createFile(incompletePath(path))
+
+	incompletePath := incompletePath(path)
+
+	// Create new incomplete file (always truncate for decompressed data)
+	f, err := createFile(incompletePath)
 	if err != nil {
 		return fmt.Errorf("create blob file: %w", err)
 	}
-	defer os.Remove(incompletePath(path))
 	defer f.Close()
 
-	if _, err := io.Copy(f, r); err != nil {
-		return fmt.Errorf("copy blob %q to store: %w", diffID.String(), err)
+	// Write data
+	written, err := io.Copy(f, r)
+	if err != nil {
+		// Keep incomplete file for potential future resume support
+		return fmt.Errorf("copy blob %q to store (wrote %d bytes): %w", diffID.String(), written, err)
 	}
 
 	f.Close() // Rename will fail on Windows if the file is still open.
-	if err := os.Rename(incompletePath(path), path); err != nil {
+	if err := os.Rename(incompletePath, path); err != nil {
 		return fmt.Errorf("rename blob file: %w", err)
 	}
+	// Clean up incomplete file after successful rename
+	os.Remove(incompletePath)
 	return nil
 }

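WriteBlob keeps the write-to-.incomplete-then-rename pattern, but now leaves the temporary file in place when the copy fails, so a later attempt has bytes to resume from. A condensed standalone sketch of that flow (helper names here are illustrative, not the package's actual API):

	// atomicWrite is an illustrative stand-in for the WriteBlob flow above.
	func atomicWrite(finalPath string, r io.Reader) error {
		tmp := finalPath + ".incomplete"
		f, err := os.Create(tmp)
		if err != nil {
			return err
		}
		if _, err := io.Copy(f, r); err != nil {
			f.Close()
			// Leave tmp behind so a future attempt can resume from it.
			return err
		}
		if err := f.Close(); err != nil { // close before rename; required on Windows
			return err
		}
		return os.Rename(tmp, finalPath)
	}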
pkg/distribution/internal/store/blobs_test.go (+6 -3)
@@ -93,14 +93,17 @@ func TestBlobs(t *testing.T) {
 		t.Fatalf("expected blob file not to exist")
 	}
 
-	// ensure incomplete file is not left behind
+	// With resumable downloads, incomplete file should be kept for potential resume
 	blobPath3, err := store.blobPath(hash)
 	if err != nil {
 		t.Fatalf("error getting blob path: %v", err)
 	}
-	if _, err := os.ReadFile(incompletePath(blobPath3)); !errors.Is(err, os.ErrNotExist) {
-		t.Fatalf("expected incomplete blob file not to exist")
+	if _, err := os.Stat(incompletePath(blobPath3)); err != nil {
+		t.Fatalf("expected incomplete blob file to exist for resume, got: %v", err)
 	}
+
+	// Clean up for other tests
+	os.Remove(incompletePath(blobPath3))
 })
 
 t.Run("WriteBlob reuses existing blob", func(t *testing.T) {
pkg/distribution/internal/store/resumable.go (new file, +184)
@@ -0,0 +1,184 @@
package store

import (
	"compress/gzip"
	"fmt"
	"io"
	"net/http"
	"os"

	"github.com/docker/model-runner/pkg/distribution/internal/progress"
	v1 "github.com/google/go-containerregistry/pkg/v1"
)

// ResumableLayer wraps a v1.Layer and adds resumable download capability
type ResumableLayer struct {
	v1.Layer
	store      *LocalStore
	httpClient *http.Client
	blobURL    string
	authToken  string
}

// NewResumableLayer creates a resumable layer wrapper
func NewResumableLayer(
	layer v1.Layer,
	store *LocalStore,
	httpClient *http.Client,
	blobURL string,
	authToken string,
) *ResumableLayer {
	return &ResumableLayer{
		Layer:      layer,
		store:      store,
		httpClient: httpClient,
		blobURL:    blobURL,
		authToken:  authToken,
	}
}

// Compressed returns an io.ReadCloser for the compressed layer contents.
// Note: resume logic is handled in DownloadAndDecompress; this method just
// returns the full layer.
func (rl *ResumableLayer) Compressed() (io.ReadCloser, error) {
	return rl.Layer.Compressed()
}

// DownloadAndDecompress downloads the layer with resume support and
// decompresses it. It returns true when a new blob was created, along with
// the blob's DiffID.
func (rl *ResumableLayer) DownloadAndDecompress(updates chan<- v1.Update) (bool, v1.Hash, error) {
	diffID, err := rl.DiffID()
	if err != nil {
		return false, v1.Hash{}, fmt.Errorf("get diff ID: %w", err)
	}

	// Check if we already have this blob
	hasBlob, err := rl.store.hasBlob(diffID)
	if err != nil {
		return false, v1.Hash{}, fmt.Errorf("check blob existence: %w", err)
	}
	if hasBlob {
		return false, diffID, nil
	}

	// Get the compressed digest
	compressedDigest, err := rl.Digest()
	if err != nil {
		return false, v1.Hash{}, fmt.Errorf("get compressed digest: %w", err)
	}

	// Get the path for storing compressed data
	compressedPath, err := rl.store.blobPath(compressedDigest)
	if err != nil {
		return false, v1.Hash{}, fmt.Errorf("get compressed path: %w", err)
	}
	// Use a distinct suffix for compressed incomplete files to avoid
	// conflicting with the decompressed .incomplete files used by WriteBlob
	compressedIncompletePath := compressedPath + ".compressed.incomplete"

	// Check for an existing incomplete file from a previous attempt
	var offset int64
	if stat, err := os.Stat(compressedIncompletePath); err == nil {
		offset = stat.Size()
	}

	// Get a compressed reader, with resume support if we have an offset
	var compressedReader io.ReadCloser
	if offset > 0 && rl.httpClient != nil && rl.blobURL != "" {
		// Try to resume with an HTTP Range request
		req, err := http.NewRequest("GET", rl.blobURL, nil)
		if err == nil {
			req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
			if rl.authToken != "" {
				req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", rl.authToken))
			}

			resp, err := rl.httpClient.Do(req)
			if err == nil {
				if resp.StatusCode == http.StatusPartialContent {
					// Successfully resumed
					compressedReader = resp.Body
				} else {
					// Server doesn't support range requests; start fresh
					resp.Body.Close()
					offset = 0
					os.Remove(compressedIncompletePath)
					// Fall through below to fetch the full layer
				}
			}
		}
	}

	// If we couldn't resume, fall back to downloading the full layer
	if compressedReader == nil {
		if offset > 0 {
			// A resume was attempted but failed (request error above); discard
			// the stale partial file so the full download is not appended
			// after its bytes.
			offset = 0
			os.Remove(compressedIncompletePath)
		}
		var err error
		compressedReader, err = rl.Layer.Compressed()
		if err != nil {
			return false, v1.Hash{}, fmt.Errorf("get compressed reader: %w", err)
		}
	}
	defer compressedReader.Close()

	// Wrap the compressed reader with progress reporting for the download
	var progressReader io.Reader = compressedReader
	if updates != nil {
		progressReader = progress.NewReader(compressedReader, updates)
	}

	// Open the file for writing (append if offset > 0)
	var compressedFile *os.File
	if offset > 0 {
		compressedFile, err = os.OpenFile(compressedIncompletePath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
		if err != nil {
			return false, v1.Hash{}, fmt.Errorf("open compressed file for append: %w", err)
		}
	} else {
		compressedFile, err = createFile(compressedIncompletePath)
		if err != nil {
			return false, v1.Hash{}, fmt.Errorf("create compressed file: %w", err)
		}
	}

	// Download compressed data with progress reporting
	written, err := io.Copy(compressedFile, progressReader)
	if err != nil {
		compressedFile.Close()
		// Keep the incomplete file so a later attempt can resume from it
		return false, v1.Hash{}, fmt.Errorf("download compressed (offset=%d, wrote=%d): %w", offset, written, err)
	}
	compressedFile.Close()

	// Decompress the complete file
	compressedFile, err = os.Open(compressedIncompletePath)
	if err != nil {
		return false, v1.Hash{}, fmt.Errorf("open for decompression: %w", err)
	}

	// Try to decompress; if that fails, the data may already be uncompressed
	gzipReader, err := gzip.NewReader(compressedFile)
	var reader io.Reader
	if err != nil {
		// Data is not gzipped, so use it directly. Reopen the file, since
		// gzip.NewReader consumed some bytes.
		compressedFile.Close()
		compressedFile, err = os.Open(compressedIncompletePath)
		if err != nil {
			return false, v1.Hash{}, fmt.Errorf("reopen for direct read: %w", err)
		}
		defer compressedFile.Close()
		reader = compressedFile
	} else {
		// gzipReader wraps compressedFile, so close them in the proper order
		defer gzipReader.Close()
		defer compressedFile.Close()
		reader = gzipReader
	}

	// Write the decompressed data (no progress wrapping here, since download
	// progress was already reported)
	if err := rl.store.WriteBlob(diffID, reader); err != nil {
		return false, v1.Hash{}, fmt.Errorf("write blob: %w", err)
	}

	// Clean up the compressed intermediate file
	os.Remove(compressedIncompletePath)

	return true, diffID, nil
}
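The resume handshake in DownloadAndDecompress is plain HTTP range semantics: send Range: bytes=<offset>- and treat only a 206 Partial Content response as a successful resume, since a 200 means the server ignored the header and is returning the whole blob. A minimal standalone sketch (url, token, and offset are illustrative placeholders):

	// openResumable opens a download stream, resuming at offset when the
	// server honors the Range header. The bool result reports whether the
	// returned body starts at offset (true) or at byte zero (false).
	func openResumable(client *http.Client, url, token string, offset int64) (io.ReadCloser, bool, error) {
		req, err := http.NewRequest(http.MethodGet, url, nil)
		if err != nil {
			return nil, false, err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
		if token != "" {
			req.Header.Set("Authorization", "Bearer "+token)
		}
		resp, err := client.Do(req)
		if err != nil {
			return nil, false, err
		}
		switch resp.StatusCode {
		case http.StatusPartialContent: // 206: resumed at offset
			return resp.Body, true, nil
		case http.StatusOK: // 200: full body; caller must discard its partial file
			return resp.Body, false, nil
		default:
			resp.Body.Close()
			return nil, false, fmt.Errorf("unexpected status %s", resp.Status)
		}
	}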