Skip to content

Commit e698c62

Browse files
authored
feat: support attach the config to existed model artifact (#225)
Signed-off-by: chlins <[email protected]>
1 parent 53d386a commit e698c62

File tree

7 files changed

+145
-120
lines changed

7 files changed

+145
-120
lines changed

cmd/attach.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ func init() {
6060
flags.BoolVar(&attachConfig.Nydusify, "nydusify", false, "[EXPERIMENTAL] nydusify the model artifact")
6161
flags.MarkHidden("nydusify")
6262
flags.BoolVar(&attachConfig.Raw, "raw", false, "turning on this flag will attach model artifact layer in raw format")
63+
flags.BoolVar(&attachConfig.Config, "config", false, "turning on this flag will overwrite model artifact config layer")
6364

6465
if err := viper.BindPFlags(flags); err != nil {
6566
panic(fmt.Errorf("bind cache list flags to viper: %w", err))

pkg/backend/attach.go

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"encoding/json"
2222
"fmt"
2323
"io"
24+
"os"
2425
"reflect"
2526
"slices"
2627
"sort"
@@ -73,33 +74,6 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac
7374

7475
logrus.Infof("attach: loaded source model config [%+v]", srcModelConfig)
7576

76-
var foundLayer *ocispec.Descriptor
77-
for _, layer := range srcManifest.Layers {
78-
if anno := layer.Annotations; anno != nil {
79-
if anno[modelspec.AnnotationFilepath] == filepath {
80-
if !cfg.Force {
81-
return fmt.Errorf("file %s already exists, please use --force to overwrite if you want to attach it forcibly", filepath)
82-
}
83-
84-
foundLayer = &layer
85-
break
86-
}
87-
}
88-
}
89-
90-
logrus.Infof("attach: found existing layer for file %s [%+v]", filepath, foundLayer)
91-
92-
layers := srcManifest.Layers
93-
if foundLayer != nil {
94-
// Remove the found layer from the layers slice as we need to replace it with the new layer.
95-
for i, layer := range layers {
96-
if layer.Digest == foundLayer.Digest && layer.MediaType == foundLayer.MediaType {
97-
layers = slices.Delete(layers, i, i+1)
98-
break
99-
}
100-
}
101-
}
102-
10377
proc := b.getProcessor(filepath, cfg.Raw)
10478
if proc == nil {
10579
return fmt.Errorf("failed to get processor for file %s", filepath)
@@ -114,40 +88,86 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac
11488
pb.Start()
11589
defer pb.Stop()
11690

117-
newLayers, err := proc.Process(ctx, builder, ".", processor.WithProgressTracker(pb))
118-
if err != nil {
119-
return fmt.Errorf("failed to process layers: %w", err)
120-
}
91+
layers := srcManifest.Layers
92+
// If attach a normal file, we need to process it and create a new layer.
93+
if !cfg.Config {
94+
var foundLayer *ocispec.Descriptor
95+
for _, layer := range srcManifest.Layers {
96+
if anno := layer.Annotations; anno != nil {
97+
if anno[modelspec.AnnotationFilepath] == filepath {
98+
if !cfg.Force {
99+
return fmt.Errorf("file %s already exists, please use --force to overwrite if you want to attach it forcibly", filepath)
100+
}
101+
102+
foundLayer = &layer
103+
break
104+
}
105+
}
106+
}
121107

122-
// Append the new layers to the original layers.
123-
layers = append(layers, newLayers...)
124-
sortLayers(layers)
108+
logrus.Infof("attach: found existing layer for file %s [%+v]", filepath, foundLayer)
109+
if foundLayer != nil {
110+
// Remove the found layer from the layers slice as we need to replace it with the new layer.
111+
for i, layer := range layers {
112+
if layer.Digest == foundLayer.Digest && layer.MediaType == foundLayer.MediaType {
113+
layers = slices.Delete(layers, i, i+1)
114+
break
115+
}
116+
}
117+
}
125118

126-
logrus.Debugf("attach: generated sorted layers [layers: %+v]", layers)
119+
newLayers, err := proc.Process(ctx, builder, ".", processor.WithProgressTracker(pb))
120+
if err != nil {
121+
return fmt.Errorf("failed to process layers: %w", err)
122+
}
127123

128-
diffIDs := []godigest.Digest{}
129-
for _, layer := range layers {
130-
diffIDs = append(diffIDs, layer.Digest)
131-
}
132-
// Return earlier if the diffID has no changed, which means the artifact has not changed.
133-
if reflect.DeepEqual(diffIDs, srcModelConfig.ModelFS.DiffIDs) {
134-
return nil
124+
// Append the new layers to the original layers.
125+
layers = append(layers, newLayers...)
126+
sortLayers(layers)
127+
128+
logrus.Debugf("attach: generated sorted layers [layers: %+v]", layers)
129+
130+
diffIDs := []godigest.Digest{}
131+
for _, layer := range layers {
132+
diffIDs = append(diffIDs, layer.Digest)
133+
}
134+
// Return earlier if the diffID has no changed, which means the artifact has not changed.
135+
if reflect.DeepEqual(diffIDs, srcModelConfig.ModelFS.DiffIDs) {
136+
return nil
137+
}
135138
}
136139

137-
// Build the model config.
138-
modelConfig := &buildconfig.Model{
139-
Architecture: srcModelConfig.Config.Architecture,
140-
Format: srcModelConfig.Config.Format,
141-
Precision: srcModelConfig.Config.Precision,
142-
Quantization: srcModelConfig.Config.Quantization,
143-
ParamSize: srcModelConfig.Config.ParamSize,
144-
Family: srcModelConfig.Descriptor.Family,
145-
Name: srcModelConfig.Descriptor.Name,
140+
var config modelspec.Model
141+
if !cfg.Config {
142+
config, err = build.BuildModelConfig(&buildconfig.Model{
143+
Architecture: srcModelConfig.Config.Architecture,
144+
Format: srcModelConfig.Config.Format,
145+
Precision: srcModelConfig.Config.Precision,
146+
Quantization: srcModelConfig.Config.Quantization,
147+
ParamSize: srcModelConfig.Config.ParamSize,
148+
Family: srcModelConfig.Descriptor.Family,
149+
Name: srcModelConfig.Descriptor.Name,
150+
SourceURL: srcModelConfig.Descriptor.SourceURL,
151+
SourceRevision: srcModelConfig.Descriptor.Revision,
152+
}, layers)
153+
if err != nil {
154+
return fmt.Errorf("failed to build model config: %w", err)
155+
}
156+
} else {
157+
configFile, err := os.Open(filepath)
158+
if err != nil {
159+
return fmt.Errorf("failed to open config file: %w", err)
160+
}
161+
defer configFile.Close()
162+
163+
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
164+
return fmt.Errorf("failed to decode config file %s: %w", filepath, err)
165+
}
146166
}
147167

148-
logrus.Infof("attach: built model config [%+v]", modelConfig)
168+
logrus.Infof("attach: built model config [%+v]", config)
149169

150-
configDesc, err := builder.BuildConfig(ctx, layers, modelConfig, hooks.NewHooks(
170+
configDesc, err := builder.BuildConfig(ctx, config, hooks.NewHooks(
151171
hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader {
152172
return pb.Add(internalpb.NormalizePrompt("Building config"), name, size, reader)
153173
}),

pkg/backend/build.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri
106106
revision += "-dirty"
107107
}
108108
// Build the model config.
109-
modelConfig := &buildconfig.Model{
109+
config, err := build.BuildModelConfig(&buildconfig.Model{
110110
Architecture: modelfile.GetArch(),
111111
Format: modelfile.GetFormat(),
112112
Precision: modelfile.GetPrecision(),
@@ -116,14 +116,17 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri
116116
Name: modelfile.GetName(),
117117
SourceURL: sourceInfo.URL,
118118
SourceRevision: revision,
119+
}, layers)
120+
if err != nil {
121+
return fmt.Errorf("failed to build model config: %w", err)
119122
}
120123

121-
logrus.Infof("build: built model config [family: %s, name: %s, format: %s]", modelConfig.Family, modelConfig.Name, modelConfig.Format)
124+
logrus.Infof("build: built model config [config: %+v]", config)
122125

123126
var configDesc ocispec.Descriptor
124127
// Build the model config.
125128
if err := retry.Do(func() error {
126-
configDesc, err = builder.BuildConfig(ctx, layers, modelConfig, hooks.NewHooks(
129+
configDesc, err = builder.BuildConfig(ctx, config, hooks.NewHooks(
127130
hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader {
128131
return pb.Add(internalpb.NormalizePrompt("Building config"), name, size, reader)
129132
}),

pkg/backend/build/builder.go

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ type Builder interface {
6161
BuildLayer(ctx context.Context, mediaType, workDir, path string, hooks hooks.Hooks) (ocispec.Descriptor, error)
6262

6363
// BuildConfig builds the config blob of the artifact.
64-
BuildConfig(ctx context.Context, layers []ocispec.Descriptor, modelConfig *buildconfig.Model, hooks hooks.Hooks) (ocispec.Descriptor, error)
64+
BuildConfig(ctx context.Context, config modelspec.Model, hooks hooks.Hooks) (ocispec.Descriptor, error)
6565

6666
// BuildManifest builds the manifest blob of the artifact.
6767
BuildManifest(ctx context.Context, layers []ocispec.Descriptor, config ocispec.Descriptor, annotations map[string]string, hooks hooks.Hooks) (ocispec.Descriptor, error)
@@ -202,12 +202,7 @@ func (ab *abstractBuilder) BuildLayer(ctx context.Context, mediaType, workDir, p
202202
return desc, nil
203203
}
204204

205-
func (ab *abstractBuilder) BuildConfig(ctx context.Context, layers []ocispec.Descriptor, modelConfig *buildconfig.Model, hooks hooks.Hooks) (ocispec.Descriptor, error) {
206-
config, err := buildModelConfig(modelConfig, layers)
207-
if err != nil {
208-
return ocispec.Descriptor{}, fmt.Errorf("failed to build model config: %w", err)
209-
}
210-
205+
func (ab *abstractBuilder) BuildConfig(ctx context.Context, config modelspec.Model, hooks hooks.Hooks) (ocispec.Descriptor, error) {
211206
configJSON, err := json.Marshal(config)
212207
if err != nil {
213208
return ocispec.Descriptor{}, fmt.Errorf("failed to marshal config: %w", err)
@@ -242,10 +237,10 @@ func (ab *abstractBuilder) BuildManifest(ctx context.Context, layers []ocispec.D
242237
return ab.strategy.OutputManifest(ctx, manifest.MediaType, digest, int64(len(manifestJSON)), bytes.NewReader(manifestJSON), hooks)
243238
}
244239

245-
// buildModelConfig builds the model config.
246-
func buildModelConfig(modelConfig *buildconfig.Model, layers []ocispec.Descriptor) (*modelspec.Model, error) {
240+
// BuildModelConfig builds the model config.
241+
func BuildModelConfig(modelConfig *buildconfig.Model, layers []ocispec.Descriptor) (modelspec.Model, error) {
247242
if modelConfig == nil {
248-
return nil, fmt.Errorf("model config is nil")
243+
return modelspec.Model{}, fmt.Errorf("model config is nil")
249244
}
250245

251246
config := modelspec.ModelConfig{
@@ -275,7 +270,7 @@ func buildModelConfig(modelConfig *buildconfig.Model, layers []ocispec.Descripto
275270
DiffIDs: diffIDs,
276271
}
277272

278-
return &modelspec.Model{
273+
return modelspec.Model{
279274
Config: config,
280275
Descriptor: descriptor,
281276
ModelFS: fs,

pkg/backend/build/builder_test.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,13 @@ func (s *BuilderTestSuite) TestBuildConfig() {
163163
Name: "llama-2",
164164
}
165165

166+
config, err := BuildModelConfig(modelConfig, []ocispec.Descriptor{})
167+
s.NoError(err)
168+
166169
s.mockOutputStrategy.On("OutputConfig", mock.Anything, modelspec.MediaTypeModelConfig, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
167170
Return(expectedDesc, nil).Once()
168171

169-
desc, err := s.builder.BuildConfig(context.Background(), []ocispec.Descriptor{}, modelConfig, hooks.NewHooks())
172+
desc, err := s.builder.BuildConfig(context.Background(), config, hooks.NewHooks())
170173
s.NoError(err)
171174
s.Equal(expectedDesc, desc)
172175

@@ -184,10 +187,13 @@ func (s *BuilderTestSuite) TestBuildConfig() {
184187
Name: "llama-2",
185188
}
186189

190+
config, err := BuildModelConfig(modelConfig, []ocispec.Descriptor{})
191+
s.NoError(err)
192+
187193
s.mockOutputStrategy.On("OutputConfig", mock.Anything, modelspec.MediaTypeModelConfig, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
188194
Return(ocispec.Descriptor{}, errors.New("output error")).Once()
189195

190-
_, err := s.builder.BuildConfig(context.Background(), []ocispec.Descriptor{}, modelConfig, hooks.NewHooks())
196+
_, err = s.builder.BuildConfig(context.Background(), config, hooks.NewHooks())
191197
s.Error(err)
192198
s.True(strings.Contains(err.Error(), "output error"))
193199
})
@@ -248,7 +254,7 @@ func (s *BuilderTestSuite) TestBuildModelConfig() {
248254
Name: "llama-2",
249255
}
250256

251-
model, err := buildModelConfig(modelConfig, []ocispec.Descriptor{
257+
model, err := BuildModelConfig(modelConfig, []ocispec.Descriptor{
252258
{Digest: godigest.Digest("sha256:layer-1")},
253259
{Digest: godigest.Digest("sha256:layer-2")},
254260
})

pkg/config/attach.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type Attach struct {
2727
Nydusify bool
2828
Force bool
2929
Raw bool
30+
Config bool
3031
}
3132

3233
func NewAttach() *Attach {
@@ -39,6 +40,7 @@ func NewAttach() *Attach {
3940
Nydusify: false,
4041
Force: false,
4142
Raw: false,
43+
Config: false,
4244
}
4345
}
4446

0 commit comments

Comments
 (0)