Skip to content

Commit 2026bea

Browse files
authored
feat: support export to output path for pull command (#36)
Signed-off-by: chlins <[email protected]>
1 parent 8c0cf95 commit 2026bea

File tree

14 files changed

+305
-18
lines changed

14 files changed

+305
-18
lines changed

cmd/extract.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package cmd
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"github.com/CloudNativeAI/modctl/pkg/backend"
24+
"github.com/CloudNativeAI/modctl/pkg/config"
25+
26+
"github.com/spf13/cobra"
27+
"github.com/spf13/viper"
28+
)
29+
30+
var extractConfig = config.NewExtract()
31+
32+
// extractCmd represents the modctl command for extract.
33+
var extractCmd = &cobra.Command{
34+
Use: "extract <target> --output <output>",
35+
Short: "A command line tool for modctl extract",
36+
Args: cobra.ExactArgs(1),
37+
DisableAutoGenTag: true,
38+
SilenceUsage: true,
39+
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
40+
RunE: func(cmd *cobra.Command, args []string) error {
41+
return runExtract(context.Background(), args[0])
42+
},
43+
}
44+
45+
// init initializes extract command.
46+
func init() {
47+
flags := extractCmd.Flags()
48+
flags.StringVar(&extractConfig.Output, "output", "", "specify the output for extracting the model artifact")
49+
50+
if err := viper.BindPFlags(flags); err != nil {
51+
panic(fmt.Errorf("bind cache extract flags to viper: %w", err))
52+
}
53+
}
54+
55+
// runExtract runs the extract modctl.
56+
func runExtract(ctx context.Context, target string) error {
57+
b, err := backend.New()
58+
if err != nil {
59+
return err
60+
}
61+
62+
if target == "" {
63+
return fmt.Errorf("target is required")
64+
}
65+
66+
if extractConfig.Output == "" {
67+
return fmt.Errorf("output is required")
68+
}
69+
70+
if err := b.Extract(ctx, target, extractConfig.Output); err != nil {
71+
return err
72+
}
73+
74+
fmt.Printf("Successfully extracted model artifact %s to %s\n", target, extractConfig.Output)
75+
return nil
76+
}

cmd/pull.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ func init() {
4747
flags := pullCmd.Flags()
4848
flags.BoolVar(&pullConfig.PlainHTTP, "plain-http", false, "use plain HTTP instead of HTTPS")
4949
flags.StringVar(&pullConfig.Proxy, "proxy", "", "use proxy for the pull operation")
50+
flags.StringVar(&pullConfig.ExtractDir, "extract-dir", "", "specify the extract dir for extracting the model artifact")
5051

5152
if err := viper.BindPFlags(flags); err != nil {
5253
panic(fmt.Errorf("bind cache pull flags to viper: %w", err))
@@ -73,6 +74,10 @@ func runPull(ctx context.Context, target string) error {
7374
opts = append(opts, backend.WithProxy(pullConfig.Proxy))
7475
}
7576

77+
if pullConfig.ExtractDir != "" {
78+
opts = append(opts, backend.WithOutput(pullConfig.ExtractDir))
79+
}
80+
7681
if err := b.Pull(ctx, target, opts...); err != nil {
7782
return err
7883
}

cmd/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,5 @@ func init() {
6666
rootCmd.AddCommand(rmCmd)
6767
rootCmd.AddCommand(pruneCmd)
6868
rootCmd.AddCommand(inspectCmd)
69+
rootCmd.AddCommand(extractCmd)
6970
}

docs/getting-started.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ Push the model artifact to the registry:
8282
$ modctl push registry.com/models/llama3:v1.0.0
8383
```
8484

85+
### Extract
86+
87+
Extract the model artifact to the specified directory:
88+
89+
```shell
90+
$ modctl extract registry.com/models/llama3:v1.0.0 --output /path/to/extract
91+
```
92+
8593
### List
8694

8795
List the model artifacts in the local storage:
Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,19 @@
1414
* limitations under the License.
1515
*/
1616

17-
package build
17+
package archiver
1818

1919
import (
2020
"archive/tar"
2121
"fmt"
2222
"io"
2323
"os"
24+
"path/filepath"
25+
"strings"
2426
)
2527

26-
// TarFileToStream tars the target file and return the content by stream.
27-
func TarFileToStream(path string) (io.Reader, error) {
28+
// Tar tars the target file and return the content by stream.
29+
func Tar(path string) (io.Reader, error) {
2830
pr, pw := io.Pipe()
2931
go func() {
3032
defer pw.Close()
@@ -65,3 +67,62 @@ func TarFileToStream(path string) (io.Reader, error) {
6567

6668
return pr, nil
6769
}
70+
71+
// Untar untars the target stream to the destination path.
72+
func Untar(reader io.Reader, destPath string) error {
73+
// uncompress gzip if it is a .tar.gz file
74+
// gzipReader, err := gzip.NewReader(reader)
75+
// if err != nil {
76+
// return err
77+
// }
78+
// defer gzipReader.Close()
79+
// tarReader := tar.NewReader(gzipReader)
80+
81+
tarReader := tar.NewReader(reader)
82+
83+
if err := os.MkdirAll(destPath, 0755); err != nil {
84+
return err
85+
}
86+
87+
for {
88+
header, err := tarReader.Next()
89+
if err == io.EOF {
90+
break
91+
}
92+
if err != nil {
93+
return err
94+
}
95+
96+
// sanitize filepaths to prevent directory traversal.
97+
cleanPath := filepath.Clean(header.Name)
98+
if strings.Contains(cleanPath, "..") {
99+
return fmt.Errorf("tar file contains invalid path: %s", cleanPath)
100+
}
101+
102+
path := filepath.Join(destPath, cleanPath)
103+
// check the file type.
104+
switch header.Typeflag {
105+
case tar.TypeDir:
106+
if err := os.MkdirAll(path, 0755); err != nil {
107+
return err
108+
}
109+
case tar.TypeReg:
110+
file, err := os.Create(path)
111+
if err != nil {
112+
return err
113+
}
114+
115+
if _, err := io.Copy(file, tarReader); err != nil {
116+
file.Close()
117+
return err
118+
}
119+
file.Close()
120+
121+
if err := os.Chmod(path, os.FileMode(header.Mode)); err != nil {
122+
return err
123+
}
124+
}
125+
}
126+
127+
return nil
128+
}

pkg/backend/backend.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ type Backend interface {
5050

5151
// Inspect inspects the model artifact.
5252
Inspect(ctx context.Context, target string) (*InspectedModelArtifact, error)
53+
54+
// Extract extracts the model artifact.
55+
Extract(ctx context.Context, target string, output string) error
5356
}
5457

5558
// backend is the implementation of Backend.

pkg/backend/build/build.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ import (
2424
"path/filepath"
2525
"time"
2626

27-
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
27+
"github.com/CloudNativeAI/modctl/pkg/archiver"
2828
"github.com/CloudNativeAI/modctl/pkg/storage"
29+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
2930

3031
godigest "github.com/opencontainers/go-digest"
3132
spec "github.com/opencontainers/image-spec/specs-go"
@@ -40,7 +41,7 @@ type ModelConfig struct {
4041

4142
// BuildLayer converts the file to the image blob and push it to the storage.
4243
func BuildLayer(ctx context.Context, store storage.Storage, repo, path, workDir string) (ocispec.Descriptor, error) {
43-
reader, err := TarFileToStream(path)
44+
reader, err := archiver.Tar(path)
4445
if err != nil {
4546
return ocispec.Descriptor{}, fmt.Errorf("failed to tar file: %w", err)
4647
}

pkg/backend/extract.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package backend
18+
19+
import (
20+
"context"
21+
"encoding/json"
22+
"fmt"
23+
"path/filepath"
24+
25+
"github.com/CloudNativeAI/modctl/pkg/archiver"
26+
"github.com/CloudNativeAI/modctl/pkg/storage"
27+
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
28+
29+
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
30+
)
31+
32+
// Extract extracts the model artifact.
33+
func (b *backend) Extract(ctx context.Context, target string, output string) error {
34+
// parse the repository and tag from the target.
35+
ref, err := ParseReference(target)
36+
if err != nil {
37+
return fmt.Errorf("failed to parse the target: %w", err)
38+
}
39+
40+
repo, tag := ref.Repository(), ref.Tag()
41+
// pull the manifest from the storage.
42+
manifestRaw, _, err := b.store.PullManifest(ctx, repo, tag)
43+
if err != nil {
44+
return fmt.Errorf("failed to pull the manifest from storage: %w", err)
45+
}
46+
// unmarshal the manifest.
47+
var manifest ocispec.Manifest
48+
if err := json.Unmarshal(manifestRaw, &manifest); err != nil {
49+
return fmt.Errorf("failed to unmarshal the manifest: %w", err)
50+
}
51+
52+
return exportModelArtifact(ctx, b.store, manifest, repo, output)
53+
}
54+
55+
// exportModelArtifact exports the target model artifact to the output directory, which will open the artifact and extract to restore the orginal repo structure.
56+
func exportModelArtifact(ctx context.Context, store storage.Storage, manifest ocispec.Manifest, repo, output string) error {
57+
for _, layer := range manifest.Layers {
58+
// pull the blob from the storage.
59+
reader, err := store.PullBlob(ctx, repo, layer.Digest.String())
60+
if err != nil {
61+
return fmt.Errorf("failed to pull the blob from storage: %w", err)
62+
}
63+
64+
defer reader.Close()
65+
66+
targetDir := output
67+
// get the original filepath in order to restore the original repo structure.
68+
originalFilePath := layer.Annotations[modelspec.AnnotationFilepath]
69+
if dir := filepath.Dir(originalFilePath); dir != "" {
70+
targetDir = filepath.Join(targetDir, dir)
71+
}
72+
73+
// untar the blob to the target directory.
74+
if err := archiver.Untar(reader, targetDir); err != nil {
75+
return fmt.Errorf("failed to untar the blob to output directory: %w", err)
76+
}
77+
}
78+
79+
return nil
80+
}

pkg/backend/options.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type Option func(*Options)
2121
type Options struct {
2222
plainHTTP bool
2323
proxy string
24+
output string
2425
}
2526

2627
// WithPlainHTTP sets the plain HTTP option.
@@ -36,3 +37,10 @@ func WithProxy(proxy string) Option {
3637
opts.proxy = proxy
3738
}
3839
}
40+
41+
// WithOutput sets the output option.
42+
func WithOutput(output string) Option {
43+
return func(opts *Options) {
44+
opts.output = output
45+
}
46+
}

pkg/backend/progress.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (p *ProgressBar) Add(prompt string, desc ocispec.Descriptor, reader io.Read
5959

6060
// create a new bar if it does not exist.
6161
bar := p.mpb.New(desc.Size,
62-
mpbv8.BarStyle().Rbound("|"),
62+
mpbv8.BarStyle(),
6363
mpbv8.BarFillerOnComplete("|"),
6464
mpbv8.PrependDecorators(
6565
decor.Name(fmt.Sprintf("%s%s", prompt, desc.Digest.String())),

0 commit comments

Comments
 (0)