Skip to content

Commit 95ad19a

Browse files
committed
deps: vendor utility dependencies
Signed-off-by: Jacob Howard <[email protected]>
1 parent bd68cb3 commit 95ad19a

File tree

6 files changed

+426
-16
lines changed

6 files changed

+426
-16
lines changed

pkg/inference/backends/llamacpp/download.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"runtime"
1414
"strings"
1515

16-
"github.com/docker/model-runner/pkg/dockerhub"
16+
"github.com/docker/model-runner/pkg/internal/dockerhub"
1717
"github.com/docker/model-runner/pkg/paths"
1818
)
1919

pkg/inference/models/manager.go

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ import (
1111
"github.com/docker/model-distribution/pkg/distribution"
1212
"github.com/docker/model-distribution/pkg/types"
1313
"github.com/docker/model-runner/pkg/inference"
14-
"github.com/docker/model-runner/pkg/ipc"
1514
"github.com/docker/model-runner/pkg/logger"
16-
"github.com/docker/model-runner/pkg/paths"
1715
)
1816

1917
const (
@@ -36,24 +34,13 @@ type Manager struct {
3634
}
3735

3836
// NewManager creates a new model's manager.
39-
func NewManager(log logger.ComponentLogger, transport http.RoundTripper) *Manager {
40-
// Create the distribution client
41-
distributionClient, err := distribution.NewClient(
42-
distribution.WithStoreRootPath(paths.DockerHome("models")),
43-
distribution.WithLogger(log.WithField("component", "model-distribution")),
44-
distribution.WithTransport(transport),
45-
distribution.WithUserAgent(ipc.UserAgent),
46-
)
47-
if err != nil {
48-
log.Errorf("Failed to create distribution client: %v", err)
49-
// Continue without distribution client
50-
}
37+
func NewManager(log logger.ComponentLogger, client *distribution.Client) *Manager {
5138
// Create the manager.
5239
m := &Manager{
5340
log: log,
5441
pullTokens: make(chan struct{}, maximumConcurrentModelPulls),
5542
router: http.NewServeMux(),
56-
distributionClient: distributionClient,
43+
distributionClient: client,
5744
}
5845

5946
// Register routes.

pkg/internal/archive/archive.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package archive
2+
3+
import (
4+
"fmt"
5+
"path/filepath"
6+
"strings"
7+
)
8+
9+
// CheckRelative returns an error if the filename path escapes dir.
10+
// This is used to protect against path traversal attacks when extracting archives.
11+
// It also rejects absolute filename paths.
12+
func CheckRelative(dir, filename string) (string, error) {
13+
if filepath.IsAbs(filename) {
14+
return "", fmt.Errorf("archive path has absolute path: %q", filename)
15+
}
16+
target := filepath.Join(dir, filename)
17+
if resolved, err := filepath.EvalSymlinks(target); err == nil {
18+
target = resolved
19+
if resolved, err = filepath.EvalSymlinks(dir); err == nil {
20+
dir = resolved
21+
}
22+
}
23+
rel, err := filepath.Rel(dir, target)
24+
if err != nil {
25+
return "", err
26+
}
27+
if strings.HasPrefix(rel, "..") {
28+
return "", fmt.Errorf("archive file %q escapes %q", target, dir)
29+
}
30+
return target, nil
31+
}
32+
33+
// CheckSymlink returns an error if the link path escapes dir.
34+
// This is used to protect against path traversal attacks when extracting archives.
35+
// It also rejects absolute linkname paths.
36+
func CheckSymlink(dir, name, linkname string) error {
37+
if filepath.IsAbs(linkname) {
38+
return fmt.Errorf("archive path has absolute link: %q", linkname)
39+
}
40+
_, err := CheckRelative(dir, filepath.Join(filepath.Dir(name), linkname))
41+
return err
42+
}

pkg/internal/dockerhub/download.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package dockerhub
2+
3+
import (
4+
"context"
5+
"encoding/base64"
6+
"errors"
7+
"fmt"
8+
"log"
9+
"os"
10+
"path/filepath"
11+
"strings"
12+
"time"
13+
14+
"github.com/containerd/containerd/content"
15+
"github.com/containerd/containerd/content/local"
16+
"github.com/containerd/containerd/images"
17+
"github.com/containerd/containerd/images/archive"
18+
"github.com/containerd/containerd/remotes"
19+
"github.com/containerd/containerd/remotes/docker"
20+
"github.com/containerd/platforms"
21+
"github.com/docker/model-runner/pkg/internal/jsonutil"
22+
v1 "github.com/opencontainers/image-spec/specs-go/v1"
23+
"github.com/sirupsen/logrus"
24+
)
25+
26+
func PullPlatform(ctx context.Context, image, destination, requiredOs, requiredArch string) error {
27+
if err := os.MkdirAll(filepath.Dir(destination), 0o755); err != nil {
28+
return fmt.Errorf("creating destination directory %s: %w", filepath.Dir(destination), err)
29+
}
30+
output, err := os.Create(destination)
31+
if err != nil {
32+
return fmt.Errorf("creating destination file %s: %w", destination, err)
33+
}
34+
tmpDir, err := os.MkdirTemp("", "docker-pull")
35+
if err != nil {
36+
return fmt.Errorf("creating temp directory: %w", err)
37+
}
38+
defer os.RemoveAll(tmpDir)
39+
store, err := local.NewStore(tmpDir)
40+
if err != nil {
41+
return fmt.Errorf("creating new content store: %w", err)
42+
}
43+
desc, err := retry(ctx, 10, 1*time.Second, func() (*v1.Descriptor, error) { return fetch(ctx, store, image, requiredOs, requiredArch) })
44+
if err != nil {
45+
return fmt.Errorf("fetching image: %w", err)
46+
}
47+
return archive.Export(ctx, store, output, archive.WithManifest(*desc, image), archive.WithSkipMissing(store))
48+
}
49+
50+
func retry(ctx context.Context, attempts int, sleep time.Duration, f func() (*v1.Descriptor, error)) (*v1.Descriptor, error) {
51+
var err error
52+
var result *v1.Descriptor
53+
for i := 0; i < attempts; i++ {
54+
if i > 0 {
55+
log.Printf("retry %d after error: %v\n", i, err)
56+
select {
57+
case <-ctx.Done():
58+
return nil, ctx.Err()
59+
case <-time.After(sleep):
60+
}
61+
}
62+
result, err = f()
63+
if err == nil {
64+
return result, nil
65+
}
66+
}
67+
return nil, fmt.Errorf("after %d attempts, last error: %s", attempts, err)
68+
}
69+
70+
func fetch(ctx context.Context, store content.Store, ref, requiredOs, requiredArch string) (*v1.Descriptor, error) {
71+
resolver := docker.NewResolver(docker.ResolverOptions{
72+
Hosts: docker.ConfigureDefaultRegistries(
73+
docker.WithAuthorizer(
74+
docker.NewDockerAuthorizer(
75+
docker.WithAuthCreds(dockerCredentials)))),
76+
})
77+
name, desc, err := resolver.Resolve(ctx, ref)
78+
if err != nil {
79+
return nil, err
80+
}
81+
fetcher, err := resolver.Fetcher(ctx, name)
82+
if err != nil {
83+
return nil, err
84+
}
85+
86+
childrenHandler := images.ChildrenHandler(store)
87+
if requiredOs != "" && requiredArch != "" {
88+
requiredPlatform := platforms.Only(v1.Platform{OS: requiredOs, Architecture: requiredArch})
89+
childrenHandler = images.LimitManifests(images.FilterPlatforms(images.ChildrenHandler(store), requiredPlatform), requiredPlatform, 1)
90+
}
91+
h := images.Handlers(remotes.FetchHandler(store, fetcher), childrenHandler)
92+
if err := images.Dispatch(ctx, h, nil, desc); err != nil {
93+
return nil, err
94+
}
95+
return &desc, nil
96+
}
97+
98+
func dockerCredentials(host string) (string, string, error) {
99+
hubUsername, hubPassword := os.Getenv("DOCKER_HUB_USER"), os.Getenv("DOCKER_HUB_PASSWORD")
100+
if hubUsername != "" && hubPassword != "" {
101+
return hubUsername, hubPassword, nil
102+
}
103+
logrus.WithField("host", host).Debug("checking for registry auth config")
104+
home, err := os.UserHomeDir()
105+
if err != nil {
106+
return "", "", err
107+
}
108+
credentialConfig := filepath.Join(home, ".docker", "config.json")
109+
cfg := struct {
110+
Auths map[string]struct {
111+
Auth string
112+
}
113+
}{}
114+
if err := jsonutil.ReadFile(credentialConfig, &cfg); err != nil {
115+
if errors.Is(err, os.ErrNotExist) {
116+
return "", "", nil
117+
}
118+
return "", "", err
119+
}
120+
for h, r := range cfg.Auths {
121+
if h == host {
122+
creds, err := base64.StdEncoding.DecodeString(r.Auth)
123+
if err != nil {
124+
return "", "", err
125+
}
126+
parts := strings.SplitN(string(creds), ":", 2)
127+
if len(parts) != 2 {
128+
logrus.Debugf("skipping not user/password auth for registry %s: %s", host, parts[0])
129+
return "", "", nil
130+
}
131+
logrus.Debugf("using auth for registry %s: user=%s", host, parts[0])
132+
return parts[0], parts[1], nil
133+
}
134+
}
135+
return "", "", nil
136+
}

0 commit comments

Comments
 (0)