Skip to content

Commit c4ea1c7

Browse files
authored
feat: support tag sub command (#90)
1 parent eb9f4b2 commit c4ea1c7

File tree

11 files changed

+504
-7
lines changed

11 files changed

+504
-7
lines changed

Makefile

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ LOCALBIN ?= $(shell pwd)/bin
9898
$(LOCALBIN):
9999
mkdir -p $(LOCALBIN)
100100

101+
MOCKERY_VERSION=v2.52.1
102+
101103
.PHONY: gen
102104
gen: gen-mockery## Generate all we need!
103105

@@ -107,15 +109,15 @@ gen-mockery: check-mockery ## Generate mockery code
107109
@mockery
108110

109111
check-mockery:
110-
@which mockery > /dev/null || { echo "mockery not found. Trying to install via Homebrew..."; $(MAKE) install-mockery; }
111-
@mockery --version | grep -q "2.46.3" || { echo "mockery version is not v2.46.3. Trying to install the correct version..."; $(MAKE) install-mockery; }
112+
@which mockery > /dev/null || { echo "mockery not found. Trying to install via go install..."; $(MAKE) install-mockery; }
113+
@mockery --version | grep -q $(MOCKERY_VERSION) || { echo "mockery version is not $(MOCKERY_VERSION). Trying to install the correct version..."; $(MAKE) install-mockery; }
112114

113115
install-mockery:
114-
@if command -v brew > /dev/null; then \
115-
echo "Installing mockery via Homebrew"; \
116-
brew install mockery; \
116+
@if command -v go > /dev/null; then \
117+
echo "Installing mockery via go install"; \
118+
go install github.com/vektra/mockery/v2@$(MOCKERY_VERSION); \
117119
else \
118-
echo "Error: Homebrew is not installed. Please install Homebrew first and ensure it's in your PATH."; \
120+
echo "Error: Golang is not installed. Please install golang first and ensure it's in your PATH."; \
119121
exit 1; \
120122
fi
121123

cmd/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,5 @@ func init() {
7979
rootCmd.AddCommand(inspectCmd)
8080
rootCmd.AddCommand(extractCmd)
8181
rootCmd.AddCommand(modelfileGenCmd)
82+
rootCmd.AddCommand(tagCmd)
8283
}

cmd/tag.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package cmd
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
"github.com/CloudNativeAI/modctl/pkg/backend"
24+
25+
"github.com/spf13/cobra"
26+
"github.com/spf13/viper"
27+
)
28+
29+
// tagCmd represents the modctl command for tag.
30+
var tagCmd = &cobra.Command{
31+
Use: "tag [flags] <source> <target>",
32+
Short: "A command line tool for modctl tag",
33+
Args: cobra.ExactArgs(2),
34+
DisableAutoGenTag: true,
35+
SilenceUsage: true,
36+
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
37+
RunE: func(cmd *cobra.Command, args []string) error {
38+
return runTag(context.Background(), args[0], args[1])
39+
},
40+
}
41+
42+
// init initializes tag command.
43+
func init() {
44+
flags := tagCmd.Flags()
45+
46+
if err := viper.BindPFlags(flags); err != nil {
47+
panic(fmt.Errorf("bind cache tag flags to viper: %w", err))
48+
}
49+
}
50+
51+
// runTag runs the tag modctl.
52+
func runTag(ctx context.Context, source, target string) error {
53+
b, err := backend.New(rootConfig.StoargeDir)
54+
if err != nil {
55+
return err
56+
}
57+
58+
if source == "" || target == "" {
59+
return fmt.Errorf("source and target are required")
60+
}
61+
62+
return b.Tag(ctx, source, target)
63+
}

pkg/backend/backend.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ type Backend interface {
5353

5454
// Extract extracts the model artifact.
5555
Extract(ctx context.Context, target string, output string) error
56+
57+
// Tag creates a new tag that refers to the source model artifact.
58+
Tag(ctx context.Context, source, target string) error
5659
}
5760

5861
// backend is the implementation of Backend.

pkg/backend/tag.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package backend
18+
19+
import (
20+
"context"
21+
"encoding/json"
22+
"fmt"
23+
24+
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
25+
)
26+
27+
// Tag creates a new tag that refers to the source model artifact.
28+
func (b *backend) Tag(ctx context.Context, source, target string) error {
29+
srcRef, err := ParseReference(source)
30+
if err != nil {
31+
return fmt.Errorf("failed to parse source: %w", err)
32+
}
33+
34+
targetRef, err := ParseReference(target)
35+
if err != nil {
36+
return fmt.Errorf("failed to parse target: %w", err)
37+
}
38+
39+
manifestRaw, _, err := b.store.PullManifest(ctx, srcRef.Repository(), srcRef.Tag())
40+
if err != nil {
41+
return fmt.Errorf("failed to pull manifest: %w", err)
42+
}
43+
44+
var manifest ocispec.Manifest
45+
if err := json.Unmarshal(manifestRaw, &manifest); err != nil {
46+
return fmt.Errorf("failed to unmarshal manifest: %w", err)
47+
}
48+
49+
// mount the blob from source.
50+
layers := []ocispec.Descriptor{manifest.Config}
51+
for _, layer := range manifest.Layers {
52+
layers = append(layers, layer)
53+
}
54+
55+
for _, layer := range layers {
56+
if err := b.store.MountBlob(ctx, srcRef.Repository(), targetRef.Repository(), layer); err != nil {
57+
return fmt.Errorf("failed to mount blob %s: %w", layer.Digest.String(), err)
58+
}
59+
}
60+
61+
if _, err := b.store.PushManifest(ctx, targetRef.Repository(), targetRef.Tag(), manifestRaw); err != nil {
62+
return fmt.Errorf("failed to push manifest: %w", err)
63+
}
64+
65+
return nil
66+
}

pkg/backend/tag_test.go

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package backend
18+
19+
import (
20+
"context"
21+
"encoding/json"
22+
"errors"
23+
"testing"
24+
25+
"github.com/CloudNativeAI/modctl/test/mocks/storage"
26+
27+
v1 "github.com/opencontainers/image-spec/specs-go/v1"
28+
"github.com/stretchr/testify/assert"
29+
"github.com/stretchr/testify/mock"
30+
)
31+
32+
func TestTag(t *testing.T) {
33+
tests := []struct {
34+
name string
35+
source string
36+
target string
37+
setupMocks func(*storage.Storage)
38+
expectedErr string
39+
}{
40+
{
41+
name: "successful tag",
42+
source: "localhost:5000/repo:tag1",
43+
target: "localhost:5000/repo:tag2",
44+
setupMocks: func(s *storage.Storage) {
45+
manifest := v1.Manifest{
46+
Config: v1.Descriptor{
47+
MediaType: "application/vnd.oci.image.config.v1+json",
48+
Digest: "sha256:config",
49+
Size: 100,
50+
},
51+
Layers: []v1.Descriptor{
52+
{
53+
MediaType: "application/vnd.oci.image.layer.v1.tar+gzip",
54+
Digest: "sha256:layer1",
55+
Size: 200,
56+
},
57+
{
58+
MediaType: "application/vnd.oci.image.layer.v1.tar+gzip",
59+
Digest: "sha256:layer2",
60+
Size: 300,
61+
},
62+
},
63+
}
64+
manifestBytes, _ := json.Marshal(manifest)
65+
s.On("PullManifest", mock.Anything, "localhost:5000/repo", "tag1").
66+
Return(manifestBytes, "sha256:manifest", nil)
67+
68+
s.On("MountBlob", mock.Anything, "localhost:5000/repo", "localhost:5000/repo", manifest.Config).
69+
Return(nil)
70+
71+
for _, layer := range manifest.Layers {
72+
s.On("MountBlob", mock.Anything, "localhost:5000/repo", "localhost:5000/repo", layer).
73+
Return(nil)
74+
}
75+
76+
s.On("PushManifest", mock.Anything, "localhost:5000/repo", "tag2", manifestBytes).
77+
Return("sha256:manifest", nil)
78+
},
79+
expectedErr: "",
80+
},
81+
{
82+
name: "invalid source reference",
83+
source: "invalid-reference",
84+
target: "localhost:5000/repo:tag2",
85+
setupMocks: func(s *storage.Storage) {
86+
// No mocks needed as we expect to fail before hitting the storage
87+
},
88+
expectedErr: "failed to parse source",
89+
},
90+
{
91+
name: "invalid target reference",
92+
source: "localhost:5000/repo:tag1",
93+
target: "invalid-reference",
94+
setupMocks: func(s *storage.Storage) {
95+
// No mocks needed as we expect to fail before hitting the storage
96+
},
97+
expectedErr: "failed to parse target",
98+
},
99+
{
100+
name: "pull manifest error",
101+
source: "localhost:5000/repo:tag1",
102+
target: "localhost:5000/repo:tag2",
103+
setupMocks: func(s *storage.Storage) {
104+
s.On("PullManifest", mock.Anything, "localhost:5000/repo", "tag1").
105+
Return([]byte{}, "", errors.New("manifest not found"))
106+
},
107+
expectedErr: "failed to pull manifest",
108+
},
109+
{
110+
name: "mount blob error",
111+
source: "localhost:5000/repo:tag1",
112+
target: "localhost:5000/repo:tag2",
113+
setupMocks: func(s *storage.Storage) {
114+
manifest := v1.Manifest{
115+
Config: v1.Descriptor{
116+
MediaType: "application/vnd.oci.image.config.v1+json",
117+
Digest: "sha256:config",
118+
Size: 100,
119+
},
120+
Layers: []v1.Descriptor{
121+
{
122+
MediaType: "application/vnd.oci.image.layer.v1.tar+gzip",
123+
Digest: "sha256:layer1",
124+
Size: 200,
125+
},
126+
},
127+
}
128+
manifestBytes, _ := json.Marshal(manifest)
129+
130+
s.On("PullManifest", mock.Anything, "localhost:5000/repo", "tag1").
131+
Return(manifestBytes, "sha256:manifest", nil)
132+
133+
s.On("MountBlob", mock.Anything, "localhost:5000/repo", "localhost:5000/repo", manifest.Config).
134+
Return(errors.New("mount blob failed"))
135+
},
136+
expectedErr: "failed to mount blob",
137+
},
138+
{
139+
name: "push manifest error",
140+
source: "localhost:5000/repo:tag1",
141+
target: "localhost:5000/repo:tag2",
142+
setupMocks: func(s *storage.Storage) {
143+
manifest := v1.Manifest{
144+
Config: v1.Descriptor{
145+
MediaType: "application/vnd.oci.image.config.v1+json",
146+
Digest: "sha256:config",
147+
Size: 100,
148+
},
149+
Layers: []v1.Descriptor{},
150+
}
151+
manifestBytes, _ := json.Marshal(manifest)
152+
s.On("PullManifest", mock.Anything, "localhost:5000/repo", "tag1").
153+
Return(manifestBytes, "sha256:manifest", nil)
154+
155+
s.On("MountBlob", mock.Anything, "localhost:5000/repo", "localhost:5000/repo", manifest.Config).
156+
Return(nil)
157+
158+
s.On("PushManifest", mock.Anything, "localhost:5000/repo", "tag2", manifestBytes).
159+
Return("", errors.New("push manifest failed"))
160+
},
161+
expectedErr: "failed to push manifest",
162+
},
163+
{
164+
name: "invalid manifest json",
165+
source: "localhost:5000/repo:tag1",
166+
target: "localhost:5000/repo:tag2",
167+
setupMocks: func(s *storage.Storage) {
168+
// Return invalid JSON as manifest
169+
s.On("PullManifest", mock.Anything, "localhost:5000/repo", "tag1").
170+
Return([]byte{123}, "sha256:invalid", nil)
171+
},
172+
expectedErr: "failed to unmarshal manifest",
173+
},
174+
}
175+
176+
for _, tt := range tests {
177+
t.Run(tt.name, func(t *testing.T) {
178+
mockStorage := storage.NewStorage(t)
179+
tt.setupMocks(mockStorage)
180+
181+
b := &backend{
182+
store: mockStorage,
183+
}
184+
185+
err := b.Tag(context.Background(), tt.source, tt.target)
186+
if tt.expectedErr != "" {
187+
assert.Error(t, err)
188+
assert.Contains(t, err.Error(), tt.expectedErr)
189+
} else {
190+
assert.NoError(t, err)
191+
}
192+
})
193+
}
194+
}

0 commit comments

Comments
 (0)