Skip to content

Commit bd2de8b

Browse files
committed
feat: support export to output path for pull command
Signed-off-by: chlins <[email protected]>
1 parent 8c0cf95 commit bd2de8b

File tree

8 files changed

+142
-20
lines changed

8 files changed

+142
-20
lines changed

cmd/pull.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func init() {
4747
flags := pullCmd.Flags()
4848
flags.BoolVar(&pullConfig.PlainHTTP, "plain-http", false, "use plain HTTP instead of HTTPS")
4949
flags.StringVar(&pullConfig.Proxy, "proxy", "", "use proxy for the pull operation")
50+
flags.StringVar(&pullConfig.OutputPath, "output", "", "specify the output path for exporting the model artifact")
5051

5152
if err := viper.BindPFlags(flags); err != nil {
5253
panic(fmt.Errorf("bind cache pull flags to viper: %w", err))
@@ -73,6 +74,10 @@ func runPull(ctx context.Context, target string) error {
7374
opts = append(opts, backend.WithProxy(pullConfig.Proxy))
7475
}
7576

77+
if pullConfig.OutputPath != "" {
78+
opts = append(opts, backend.WithOutputPath(pullConfig.OutputPath))
79+
}
80+
7681
if err := b.Pull(ctx, target, opts...); err != nil {
7782
return err
7883
}
Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@
1414
* limitations under the License.
1515
*/
1616

17-
package build
17+
package archiver
1818

1919
import (
2020
"archive/tar"
2121
"fmt"
2222
"io"
2323
"os"
24+
"path/filepath"
25+
"strings"
2426
)
2527

26-
// TarFileToStream tars the target file and return the content by stream.
27-
func TarFileToStream(path string) (io.Reader, error) {
28+
// Tar tars the target file and return the content by stream.
29+
func Tar(path string) (io.Reader, error) {
2830
pr, pw := io.Pipe()
2931
go func() {
3032
defer pw.Close()
@@ -65,3 +67,62 @@ func TarFileToStream(path string) (io.Reader, error) {
6567

6668
return pr, nil
6769
}
70+
71+
// Untar untars the target stream to the destination path.
72+
func Untar(reader io.Reader, destPath string) error {
73+
// uncompress gzip if it is a .tar.gz file
74+
// gzipReader, err := gzip.NewReader(reader)
75+
// if err != nil {
76+
// return err
77+
// }
78+
// defer gzipReader.Close()
79+
// tarReader := tar.NewReader(gzipReader)
80+
81+
tarReader := tar.NewReader(reader)
82+
83+
if err := os.MkdirAll(destPath, 0755); err != nil {
84+
return err
85+
}
86+
87+
for {
88+
header, err := tarReader.Next()
89+
if err == io.EOF {
90+
break
91+
}
92+
if err != nil {
93+
return err
94+
}
95+
96+
// sanitize filepaths to prevent directory traversal.
97+
cleanPath := filepath.Clean(header.Name)
98+
if strings.Contains(cleanPath, "..") {
99+
return fmt.Errorf("tar file contains invalid path: %s", cleanPath)
100+
}
101+
102+
path := filepath.Join(destPath, cleanPath)
103+
// check the file type.
104+
switch header.Typeflag {
105+
case tar.TypeDir:
106+
if err := os.MkdirAll(path, 0755); err != nil {
107+
return err
108+
}
109+
case tar.TypeReg:
110+
file, err := os.Create(path)
111+
if err != nil {
112+
return err
113+
}
114+
115+
if _, err := io.Copy(file, tarReader); err != nil {
116+
file.Close()
117+
return err
118+
}
119+
file.Close()
120+
121+
if err := os.Chmod(path, os.FileMode(header.Mode)); err != nil {
122+
return err
123+
}
124+
}
125+
}
126+
127+
return nil
128+
}

pkg/backend/build/build.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ import (
2424
"path/filepath"
2525
"time"
2626

27-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
27+
"github.com/CloudNativeAI/modctl/pkg/archiver"
2828
"github.com/CloudNativeAI/modctl/pkg/storage"
29+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2930

3031
godigest "github.com/opencontainers/go-digest"
3132
spec "github.com/opencontainers/image-spec/specs-go"
@@ -40,7 +41,7 @@ type ModelConfig struct {
4041

4142
// BuildLayer converts the file to the image blob and push it to the storage.
4243
func BuildLayer(ctx context.Context, store storage.Storage, repo, path, workDir string) (ocispec.Descriptor, error) {
43-
reader, err := TarFileToStream(path)
44+
reader, err := archiver.Tar(path)
4445
if err != nil {
4546
return ocispec.Descriptor{}, fmt.Errorf("failed to tar file: %w", err)
4647
}

pkg/backend/options.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ package backend
1919
type Option func(*Options)
2020

2121
type Options struct {
22-
plainHTTP bool
23-
proxy string
22+
plainHTTP bool
23+
proxy string
24+
outputPath string
2425
}
2526

2627
// WithPlainHTTP sets the plain HTTP option.
@@ -36,3 +37,10 @@ func WithProxy(proxy string) Option {
3637
opts.proxy = proxy
3738
}
3839
}
40+
41+
// WithOutputPath sets the output path option.
42+
func WithOutputPath(outputPath string) Option {
43+
return func(opts *Options) {
44+
opts.outputPath = outputPath
45+
}
46+
}

pkg/backend/progress.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (p *ProgressBar) Add(prompt string, desc ocispec.Descriptor, reader io.Read
5959

6060
// create a new bar if it does not exist.
6161
bar := p.mpb.New(desc.Size,
62-
mpbv8.BarStyle().Rbound("|"),
62+
mpbv8.BarStyle(),
6363
mpbv8.BarFillerOnComplete("|"),
6464
mpbv8.PrependDecorators(
6565
decor.Name(fmt.Sprintf("%s%s", prompt, desc.Digest.String())),

pkg/backend/pull.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ import (
2323
"io"
2424
"net/http"
2525
"net/url"
26+
"path/filepath"
2627

28+
"github.com/CloudNativeAI/modctl/pkg/archiver"
2729
"github.com/CloudNativeAI/modctl/pkg/storage"
2830

31+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2932
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
3033
"oras.land/oras-go/v2/registry/remote"
3134
"oras.land/oras-go/v2/registry/remote/auth"
@@ -124,6 +127,13 @@ func (b *backend) Pull(ctx context.Context, target string, opts ...Option) error
124127
return fmt.Errorf("failed to pull manifest to local: %w", err)
125128
}
126129

130+
// export the target model artifact to the output directory if needed.
131+
if options.outputPath != "" {
132+
if err := exportModelArtifact(ctx, dst, manifest, repo, options.outputPath); err != nil {
133+
return fmt.Errorf("failed to export the artifact to the output directory: %w", err)
134+
}
135+
}
136+
127137
return nil
128138
}
129139

@@ -162,3 +172,30 @@ func pullIfNotExist(ctx context.Context, pb *ProgressBar, prompt string, src *re
162172

163173
return nil
164174
}
175+
176+
// exportModelArtifact exports the target model artifact to the output directory, which will open the artifact and extract to restore the orginal repo structure.
177+
func exportModelArtifact(ctx context.Context, store storage.Storage, manifest ocispec.Manifest, repo, outputPath string) error {
178+
for _, layer := range manifest.Layers {
179+
// pull the blob from the storage.
180+
reader, err := store.PullBlob(ctx, repo, layer.Digest.String())
181+
if err != nil {
182+
return fmt.Errorf("failed to pull the blob from storage: %w", err)
183+
}
184+
185+
defer reader.Close()
186+
187+
targetPath := outputPath
188+
// get the original filepath in order to restore the original repo structure.
189+
originalFilePath := layer.Annotations[modelspec.AnnotationFilepath]
190+
if dir := filepath.Dir(originalFilePath); dir != "" {
191+
targetPath = filepath.Join(targetPath, dir)
192+
}
193+
194+
// untar the blob to the output directory.
195+
if err := archiver.Untar(reader, targetPath); err != nil {
196+
return fmt.Errorf("failed to untar the blob to output directory: %w", err)
197+
}
198+
}
199+
200+
return nil
201+
}

pkg/backend/push.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package backend
1818

1919
import (
20+
"bytes"
2021
"context"
2122
"encoding/json"
2223
"fmt"
@@ -127,17 +128,16 @@ func pushIfNotExist(ctx context.Context, pb *ProgressBar, prompt string, src sto
127128
return nil
128129
}
129130

130-
// fetch the content from the source storage.
131-
content, err := src.PullBlob(ctx, repo, desc.Digest.String())
132-
if err != nil {
133-
return fmt.Errorf("failed to fetch the content from source: %w", err)
134-
}
135-
136-
defer content.Close()
137131
// push the content to the destination, and wrap the content reader for progress bar,
138132
// manifest should use dst.Manifests().Push, others should use dst.Blobs().Push.
139133
if desc.MediaType == ocispec.MediaTypeImageManifest {
140-
if err := dst.Manifests().Push(ctx, desc, pb.Add(prompt, desc, content)); err != nil {
134+
// fetch the manifest from the source storage.
135+
manifestRaw, _, err := src.PullManifest(ctx, repo, tag)
136+
if err != nil {
137+
return fmt.Errorf("failed to fetch the manifest from source: %w", err)
138+
}
139+
140+
if err := dst.Manifests().Push(ctx, desc, pb.Add(prompt, desc, bytes.NewReader(manifestRaw))); err != nil {
141141
return err
142142
}
143143

@@ -146,6 +146,14 @@ func pushIfNotExist(ctx context.Context, pb *ProgressBar, prompt string, src sto
146146
return err
147147
}
148148
} else {
149+
// fetch the content from the source storage.
150+
content, err := src.PullBlob(ctx, repo, desc.Digest.String())
151+
if err != nil {
152+
return fmt.Errorf("failed to fetch the content from source: %w", err)
153+
}
154+
155+
defer content.Close()
156+
149157
if err := dst.Blobs().Push(ctx, desc, pb.Add(prompt, desc, content)); err != nil {
150158
return err
151159
}

pkg/config/pull.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
package config
1818

1919
type Pull struct {
20-
PlainHTTP bool
21-
Proxy string
20+
PlainHTTP bool
21+
Proxy string
22+
OutputPath string
2223
}
2324

2425
func NewPull() *Pull {
2526
return &Pull{
26-
PlainHTTP: false,
27-
Proxy: "",
27+
PlainHTTP: false,
28+
Proxy: "",
29+
OutputPath: "",
2830
}
2931
}

0 commit comments

Comments
 (0)