Skip to content

Commit 8534718

Browse files
authored
feat: support inspect the model artifact from remote registry (#189)
Signed-off-by: chlins <[email protected]>
1 parent eb286b8 commit 8534718

File tree

7 files changed

+68
-35
lines changed

7 files changed

+68
-35
lines changed

cmd/inspect.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,14 @@ import (
2222
"fmt"
2323

2424
"github.com/CloudNativeAI/modctl/pkg/backend"
25+
"github.com/CloudNativeAI/modctl/pkg/config"
2526

2627
"github.com/spf13/cobra"
2728
"github.com/spf13/viper"
2829
)
2930

31+
var inspectConfig = config.NewInspect()
32+
3033
// inspectCmd represents the modctl command for inspect.
3134
var inspectCmd = &cobra.Command{
3235
Use: "inspect [flags] <target>",
@@ -42,7 +45,10 @@ var inspectCmd = &cobra.Command{
4245

4346
// init initializes inspect command.
4447
func init() {
45-
flags := rmCmd.Flags()
48+
flags := inspectCmd.Flags()
49+
flags.BoolVar(&inspectConfig.Remote, "remote", false, "inspect model artifact from remote registry")
50+
flags.BoolVar(&inspectConfig.PlainHTTP, "plain-http", false, "use plain HTTP instead of HTTPS")
51+
flags.BoolVar(&inspectConfig.Insecure, "insecure", false, "allow insecure connections")
4652

4753
if err := viper.BindPFlags(flags); err != nil {
4854
panic(fmt.Errorf("bind cache inspect flags to viper: %w", err))
@@ -60,7 +66,7 @@ func runInspect(ctx context.Context, target string) error {
6066
return fmt.Errorf("target is required")
6167
}
6268

63-
inspected, err := b.Inspect(ctx, target)
69+
inspected, err := b.Inspect(ctx, target, inspectConfig)
6470
if err != nil {
6571
return err
6672
}

pkg/backend/attach.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ var (
5858

5959
// Attach attaches user materials into the model artifact which follows the Model Spec.
6060
func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attach) error {
61-
srcManifest, err := b.getManifest(ctx, cfg.Source, cfg)
61+
srcManifest, err := b.getManifest(ctx, cfg.Source, cfg.OutputRemote, cfg.PlainHTTP, cfg.Insecure)
6262
if err != nil {
6363
return fmt.Errorf("failed to get source manifest: %w", err)
6464
}
6565

66-
srcModelConfig, err := b.getModelConfig(ctx, cfg.Source, srcManifest.Config, cfg)
66+
srcModelConfig, err := b.getModelConfig(ctx, cfg.Source, srcManifest.Config, cfg.OutputRemote, cfg.PlainHTTP, cfg.Insecure)
6767
if err != nil {
6868
return fmt.Errorf("failed to get source model config: %w", err)
6969
}
@@ -169,7 +169,7 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac
169169
return nil
170170
}
171171

172-
func (b *backend) getManifest(ctx context.Context, reference string, cfg *config.Attach) (*ocispec.Manifest, error) {
172+
func (b *backend) getManifest(ctx context.Context, reference string, fromRemote, plainHTTP, insecure bool) (*ocispec.Manifest, error) {
173173
ref, err := ParseReference(reference)
174174
if err != nil {
175175
return nil, fmt.Errorf("failed to parse source reference: %w", err)
@@ -181,7 +181,7 @@ func (b *backend) getManifest(ctx context.Context, reference string, cfg *config
181181
}
182182

183183
// Fetch from local storage if it is not remote.
184-
if !cfg.OutputRemote {
184+
if !fromRemote {
185185
manifestRaw, _, err := b.store.PullManifest(ctx, repo, tag)
186186
if err != nil {
187187
return nil, fmt.Errorf("failed to pull manifest: %w", err)
@@ -195,7 +195,7 @@ func (b *backend) getManifest(ctx context.Context, reference string, cfg *config
195195
return &manifest, nil
196196
}
197197

198-
client, err := remote.New(repo, remote.WithPlainHTTP(cfg.PlainHTTP), remote.WithInsecure(cfg.Insecure))
198+
client, err := remote.New(repo, remote.WithPlainHTTP(plainHTTP), remote.WithInsecure(insecure))
199199
if err != nil {
200200
return nil, fmt.Errorf("failed to create remote client: %w", err)
201201
}
@@ -214,7 +214,7 @@ func (b *backend) getManifest(ctx context.Context, reference string, cfg *config
214214
return &manifest, nil
215215
}
216216

217-
func (b *backend) getModelConfig(ctx context.Context, reference string, desc ocispec.Descriptor, cfg *config.Attach) (*modelspec.Model, error) {
217+
func (b *backend) getModelConfig(ctx context.Context, reference string, desc ocispec.Descriptor, fromRemote, plainHTTP, insecure bool) (*modelspec.Model, error) {
218218
ref, err := ParseReference(reference)
219219
if err != nil {
220220
return nil, fmt.Errorf("failed to parse reference: %w", err)
@@ -226,7 +226,7 @@ func (b *backend) getModelConfig(ctx context.Context, reference string, desc oci
226226
}
227227

228228
// Fetch from local storage if it is not remote.
229-
if !cfg.OutputRemote {
229+
if !fromRemote {
230230
reader, err := b.store.PullBlob(ctx, repo, desc.Digest.String())
231231
if err != nil {
232232
return nil, fmt.Errorf("failed to pull blob: %w", err)
@@ -241,7 +241,7 @@ func (b *backend) getModelConfig(ctx context.Context, reference string, desc oci
241241
return &model, nil
242242
}
243243

244-
client, err := remote.New(repo, remote.WithPlainHTTP(cfg.PlainHTTP), remote.WithInsecure(cfg.Insecure))
244+
client, err := remote.New(repo, remote.WithPlainHTTP(plainHTTP), remote.WithInsecure(insecure))
245245
if err != nil {
246246
return nil, fmt.Errorf("failed to create remote client: %w", err)
247247
}

pkg/backend/attach_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ func TestBackendGetManifest(t *testing.T) {
4444
mockStore.On("PullManifest", ctx, "localhost/repo", "tag").Return(manifestBytes, "", nil)
4545

4646
cfg := &config.Attach{OutputRemote: false}
47-
result, err := b.getManifest(ctx, "localhost/repo:tag", cfg)
47+
result, err := b.getManifest(ctx, "localhost/repo:tag", cfg.OutputRemote, cfg.PlainHTTP, cfg.Insecure)
4848
assert.NoError(t, err)
4949
assert.Equal(t, manifest.Layers, result.Layers)
5050
mockStore.AssertExpectations(t)
5151
})
5252

5353
t.Run("InvalidReference", func(t *testing.T) {
5454
cfg := &config.Attach{OutputRemote: false}
55-
_, err := b.getManifest(ctx, "invalid", cfg)
55+
_, err := b.getManifest(ctx, "invalid", cfg.OutputRemote, cfg.PlainHTTP, cfg.Insecure)
5656
assert.Error(t, err)
5757
assert.Contains(t, err.Error(), "failed to parse source reference")
5858
})

pkg/backend/backend.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type Backend interface {
5656
Prune(ctx context.Context, dryRun, removeUntagged bool) error
5757

5858
// Inspect inspects the model artifact.
59-
Inspect(ctx context.Context, target string) (*InspectedModelArtifact, error)
59+
Inspect(ctx context.Context, target string, cfg *config.Inspect) (*InspectedModelArtifact, error)
6060

6161
// Extract extracts the model artifact.
6262
Extract(ctx context.Context, target string, cfg *config.Extract) error

pkg/backend/inspect.go

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import (
2222
"fmt"
2323
"time"
2424

25+
godigest "github.com/opencontainers/go-digest"
26+
27+
"github.com/CloudNativeAI/modctl/pkg/config"
2528
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
26-
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
2729
)
2830

2931
// InspectedModelArtifact is the data structure for model artifact that has been inspected.
@@ -63,38 +65,30 @@ type InspectedModelArtifactLayer struct {
6365
}
6466

6567
// Inspect inspects the target from the storage.
66-
func (b *backend) Inspect(ctx context.Context, target string) (*InspectedModelArtifact, error) {
67-
ref, err := ParseReference(target)
68+
func (b *backend) Inspect(ctx context.Context, target string, cfg *config.Inspect) (*InspectedModelArtifact, error) {
69+
_, err := ParseReference(target)
6870
if err != nil {
6971
return nil, fmt.Errorf("failed to parse target: %w", err)
7072
}
7173

72-
repo, tag := ref.Repository(), ref.Tag()
73-
manifestRaw, digest, err := b.store.PullManifest(ctx, repo, tag)
74+
manifest, err := b.getManifest(ctx, target, cfg.Remote, cfg.PlainHTTP, cfg.Insecure)
7475
if err != nil {
7576
return nil, fmt.Errorf("failed to get manifest: %w", err)
7677
}
7778

78-
var manifest ocispec.Manifest
79-
if err := json.Unmarshal(manifestRaw, &manifest); err != nil {
80-
return nil, fmt.Errorf("failed to unmarshal manifest: %w", err)
81-
}
82-
83-
// fetch and parse the model config.
84-
configReader, err := b.store.PullBlob(ctx, repo, manifest.Config.Digest.String())
79+
manifestRaw, err := json.Marshal(manifest)
8580
if err != nil {
86-
return nil, fmt.Errorf("failed to pull config: %w", err)
81+
return nil, fmt.Errorf("failed to marshal manifest: %w", err)
8782
}
8883

89-
defer configReader.Close()
90-
var config modelspec.Model
91-
if err := json.NewDecoder(configReader).Decode(&config); err != nil {
92-
return nil, fmt.Errorf("failed to decode config: %w", err)
84+
config, err := b.getModelConfig(ctx, target, manifest.Config, cfg.Remote, cfg.PlainHTTP, cfg.Insecure)
85+
if err != nil {
86+
return nil, fmt.Errorf("failed to get config: %w", err)
9387
}
9488

9589
inspectedModelArtifact := &InspectedModelArtifact{
9690
ID: manifest.Config.Digest.String(),
97-
Digest: digest,
91+
Digest: godigest.FromBytes(manifestRaw).String(),
9892
Architecture: config.Config.Architecture,
9993
Family: config.Descriptor.Family,
10094
Format: config.Config.Format,

pkg/backend/inspect_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ import (
2222
"io"
2323
"testing"
2424

25-
"github.com/CloudNativeAI/modctl/test/mocks/storage"
2625
"github.com/stretchr/testify/assert"
26+
27+
pkgconfig "github.com/CloudNativeAI/modctl/pkg/config"
28+
"github.com/CloudNativeAI/modctl/test/mocks/storage"
2729
)
2830

2931
func TestInspect(t *testing.T) {
@@ -126,13 +128,13 @@ func TestInspect(t *testing.T) {
126128
}
127129
}`
128130

129-
mockStore.On("PullManifest", ctx, "example.com/repo", "tag").Return([]byte(manifest), "sha256:2bc8836f5910ec63a01109e20db67c2ad7706cb19bef5a303bc86fa5572ec9a2", nil)
131+
mockStore.On("PullManifest", ctx, "example.com/repo", "tag").Return([]byte(manifest), "sha256:9ca701e8784e5656e2c36f10f82410a0af4c44f859590a28a3d1519ee1eea89d", nil)
130132
mockStore.On("PullBlob", ctx, "example.com/repo", "sha256:e31b55920173ba79526491fbd01efe609c1d0d72c3a83df85b2c4fe74df2eea2").Return(io.NopCloser(bytes.NewReader([]byte(config))), nil)
131133

132-
inspected, err := b.Inspect(ctx, target)
134+
inspected, err := b.Inspect(ctx, target, &pkgconfig.Inspect{})
133135
assert.NoError(t, err)
134136
assert.Equal(t, "sha256:e31b55920173ba79526491fbd01efe609c1d0d72c3a83df85b2c4fe74df2eea2", inspected.ID)
135-
assert.Equal(t, "sha256:2bc8836f5910ec63a01109e20db67c2ad7706cb19bef5a303bc86fa5572ec9a2", inspected.Digest)
137+
assert.Equal(t, "sha256:9ca701e8784e5656e2c36f10f82410a0af4c44f859590a28a3d1519ee1eea89d", inspected.Digest)
136138
assert.Equal(t, "transformer", inspected.Architecture)
137139
assert.Equal(t, "2025-02-12T17:01:43+08:00", inspected.CreatedAt)
138140
assert.Equal(t, "qwen2", inspected.Family)

pkg/config/inspect.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package config
18+
19+
type Inspect struct {
20+
Remote bool
21+
PlainHTTP bool
22+
Insecure bool
23+
}
24+
25+
func NewInspect() *Inspect {
26+
return &Inspect{
27+
Remote: false,
28+
PlainHTTP: false,
29+
Insecure: false,
30+
}
31+
}

0 commit comments

Comments
 (0)