Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/attach.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func init() {
flags.BoolVar(&attachConfig.Nydusify, "nydusify", false, "[EXPERIMENTAL] nydusify the model artifact")
flags.MarkHidden("nydusify")
flags.BoolVar(&attachConfig.Raw, "raw", false, "turning on this flag will attach model artifact layer in raw format")
flags.BoolVar(&attachConfig.Config, "config", false, "turning on this flag will overwrite model artifact config layer")

if err := viper.BindPFlags(flags); err != nil {
panic(fmt.Errorf("bind cache list flags to viper: %w", err))
Expand Down
126 changes: 73 additions & 53 deletions pkg/backend/attach.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"io"
"os"
"reflect"
"slices"
"sort"
Expand Down Expand Up @@ -73,33 +74,6 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac

logrus.Infof("attach: loaded source model config [%+v]", srcModelConfig)

var foundLayer *ocispec.Descriptor
for _, layer := range srcManifest.Layers {
if anno := layer.Annotations; anno != nil {
if anno[modelspec.AnnotationFilepath] == filepath {
if !cfg.Force {
return fmt.Errorf("file %s already exists, please use --force to overwrite if you want to attach it forcibly", filepath)
}

foundLayer = &layer
break
}
}
}

logrus.Infof("attach: found existing layer for file %s [%+v]", filepath, foundLayer)

layers := srcManifest.Layers
if foundLayer != nil {
// Remove the found layer from the layers slice as we need to replace it with the new layer.
for i, layer := range layers {
if layer.Digest == foundLayer.Digest && layer.MediaType == foundLayer.MediaType {
layers = slices.Delete(layers, i, i+1)
break
}
}
}

proc := b.getProcessor(filepath, cfg.Raw)
if proc == nil {
return fmt.Errorf("failed to get processor for file %s", filepath)
Expand All @@ -114,40 +88,86 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac
pb.Start()
defer pb.Stop()

newLayers, err := proc.Process(ctx, builder, ".", processor.WithProgressTracker(pb))
if err != nil {
return fmt.Errorf("failed to process layers: %w", err)
}
layers := srcManifest.Layers
// If attach a normal file, we need to process it and create a new layer.
if !cfg.Config {
var foundLayer *ocispec.Descriptor
for _, layer := range srcManifest.Layers {
if anno := layer.Annotations; anno != nil {
if anno[modelspec.AnnotationFilepath] == filepath {
if !cfg.Force {
return fmt.Errorf("file %s already exists, please use --force to overwrite if you want to attach it forcibly", filepath)
}

foundLayer = &layer
break
}
}
}

// Append the new layers to the original layers.
layers = append(layers, newLayers...)
sortLayers(layers)
logrus.Infof("attach: found existing layer for file %s [%+v]", filepath, foundLayer)
if foundLayer != nil {
// Remove the found layer from the layers slice as we need to replace it with the new layer.
for i, layer := range layers {
if layer.Digest == foundLayer.Digest && layer.MediaType == foundLayer.MediaType {
layers = slices.Delete(layers, i, i+1)
break
}
}
}

logrus.Debugf("attach: generated sorted layers [layers: %+v]", layers)
newLayers, err := proc.Process(ctx, builder, ".", processor.WithProgressTracker(pb))
if err != nil {
return fmt.Errorf("failed to process layers: %w", err)
}

diffIDs := []godigest.Digest{}
for _, layer := range layers {
diffIDs = append(diffIDs, layer.Digest)
}
// Return earlier if the diffID has no changed, which means the artifact has not changed.
if reflect.DeepEqual(diffIDs, srcModelConfig.ModelFS.DiffIDs) {
return nil
// Append the new layers to the original layers.
layers = append(layers, newLayers...)
sortLayers(layers)

logrus.Debugf("attach: generated sorted layers [layers: %+v]", layers)

diffIDs := []godigest.Digest{}
for _, layer := range layers {
diffIDs = append(diffIDs, layer.Digest)
}
// Return earlier if the diffID has no changed, which means the artifact has not changed.
if reflect.DeepEqual(diffIDs, srcModelConfig.ModelFS.DiffIDs) {
return nil
}
}

// Build the model config.
modelConfig := &buildconfig.Model{
Architecture: srcModelConfig.Config.Architecture,
Format: srcModelConfig.Config.Format,
Precision: srcModelConfig.Config.Precision,
Quantization: srcModelConfig.Config.Quantization,
ParamSize: srcModelConfig.Config.ParamSize,
Family: srcModelConfig.Descriptor.Family,
Name: srcModelConfig.Descriptor.Name,
var config modelspec.Model
if !cfg.Config {
config, err = build.BuildModelConfig(&buildconfig.Model{
Architecture: srcModelConfig.Config.Architecture,
Format: srcModelConfig.Config.Format,
Precision: srcModelConfig.Config.Precision,
Quantization: srcModelConfig.Config.Quantization,
ParamSize: srcModelConfig.Config.ParamSize,
Family: srcModelConfig.Descriptor.Family,
Name: srcModelConfig.Descriptor.Name,
SourceURL: srcModelConfig.Descriptor.SourceURL,
SourceRevision: srcModelConfig.Descriptor.Revision,
}, layers)
if err != nil {
return fmt.Errorf("failed to build model config: %w", err)
}
} else {
configFile, err := os.Open(filepath)
if err != nil {
return fmt.Errorf("failed to open config file: %w", err)
}
defer configFile.Close()

if err := json.NewDecoder(configFile).Decode(&config); err != nil {
return fmt.Errorf("failed to decode config file %s: %w", filepath, err)
}
}

logrus.Infof("attach: built model config [%+v]", modelConfig)
logrus.Infof("attach: built model config [%+v]", config)

configDesc, err := builder.BuildConfig(ctx, layers, modelConfig, hooks.NewHooks(
configDesc, err := builder.BuildConfig(ctx, config, hooks.NewHooks(
hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader {
return pb.Add(internalpb.NormalizePrompt("Building config"), name, size, reader)
}),
Expand Down
9 changes: 6 additions & 3 deletions pkg/backend/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri
revision += "-dirty"
}
// Build the model config.
modelConfig := &buildconfig.Model{
config, err := build.BuildModelConfig(&buildconfig.Model{
Architecture: modelfile.GetArch(),
Format: modelfile.GetFormat(),
Precision: modelfile.GetPrecision(),
Expand All @@ -116,14 +116,17 @@ func (b *backend) Build(ctx context.Context, modelfilePath, workDir, target stri
Name: modelfile.GetName(),
SourceURL: sourceInfo.URL,
SourceRevision: revision,
}, layers)
if err != nil {
return fmt.Errorf("failed to build model config: %w", err)
}

logrus.Infof("build: built model config [family: %s, name: %s, format: %s]", modelConfig.Family, modelConfig.Name, modelConfig.Format)
logrus.Infof("build: built model config [config: %+v]", config)

var configDesc ocispec.Descriptor
// Build the model config.
if err := retry.Do(func() error {
configDesc, err = builder.BuildConfig(ctx, layers, modelConfig, hooks.NewHooks(
configDesc, err = builder.BuildConfig(ctx, config, hooks.NewHooks(
hooks.WithOnStart(func(name string, size int64, reader io.Reader) io.Reader {
return pb.Add(internalpb.NormalizePrompt("Building config"), name, size, reader)
}),
Expand Down
17 changes: 6 additions & 11 deletions pkg/backend/build/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type Builder interface {
BuildLayer(ctx context.Context, mediaType, workDir, path string, hooks hooks.Hooks) (ocispec.Descriptor, error)

// BuildConfig builds the config blob of the artifact.
BuildConfig(ctx context.Context, layers []ocispec.Descriptor, modelConfig *buildconfig.Model, hooks hooks.Hooks) (ocispec.Descriptor, error)
BuildConfig(ctx context.Context, config modelspec.Model, hooks hooks.Hooks) (ocispec.Descriptor, error)

// BuildManifest builds the manifest blob of the artifact.
BuildManifest(ctx context.Context, layers []ocispec.Descriptor, config ocispec.Descriptor, annotations map[string]string, hooks hooks.Hooks) (ocispec.Descriptor, error)
Expand Down Expand Up @@ -202,12 +202,7 @@ func (ab *abstractBuilder) BuildLayer(ctx context.Context, mediaType, workDir, p
return desc, nil
}

func (ab *abstractBuilder) BuildConfig(ctx context.Context, layers []ocispec.Descriptor, modelConfig *buildconfig.Model, hooks hooks.Hooks) (ocispec.Descriptor, error) {
config, err := buildModelConfig(modelConfig, layers)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to build model config: %w", err)
}

func (ab *abstractBuilder) BuildConfig(ctx context.Context, config modelspec.Model, hooks hooks.Hooks) (ocispec.Descriptor, error) {
configJSON, err := json.Marshal(config)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to marshal config: %w", err)
Expand Down Expand Up @@ -242,10 +237,10 @@ func (ab *abstractBuilder) BuildManifest(ctx context.Context, layers []ocispec.D
return ab.strategy.OutputManifest(ctx, manifest.MediaType, digest, int64(len(manifestJSON)), bytes.NewReader(manifestJSON), hooks)
}

// buildModelConfig builds the model config.
func buildModelConfig(modelConfig *buildconfig.Model, layers []ocispec.Descriptor) (*modelspec.Model, error) {
// BuildModelConfig builds the model config.
func BuildModelConfig(modelConfig *buildconfig.Model, layers []ocispec.Descriptor) (modelspec.Model, error) {
if modelConfig == nil {
return nil, fmt.Errorf("model config is nil")
return modelspec.Model{}, fmt.Errorf("model config is nil")
}

config := modelspec.ModelConfig{
Expand Down Expand Up @@ -275,7 +270,7 @@ func buildModelConfig(modelConfig *buildconfig.Model, layers []ocispec.Descripto
DiffIDs: diffIDs,
}

return &modelspec.Model{
return modelspec.Model{
Config: config,
Descriptor: descriptor,
ModelFS: fs,
Expand Down
12 changes: 9 additions & 3 deletions pkg/backend/build/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,13 @@ func (s *BuilderTestSuite) TestBuildConfig() {
Name: "llama-2",
}

config, err := BuildModelConfig(modelConfig, []ocispec.Descriptor{})
s.NoError(err)

s.mockOutputStrategy.On("OutputConfig", mock.Anything, modelspec.MediaTypeModelConfig, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(expectedDesc, nil).Once()

desc, err := s.builder.BuildConfig(context.Background(), []ocispec.Descriptor{}, modelConfig, hooks.NewHooks())
desc, err := s.builder.BuildConfig(context.Background(), config, hooks.NewHooks())
s.NoError(err)
s.Equal(expectedDesc, desc)

Expand All @@ -184,10 +187,13 @@ func (s *BuilderTestSuite) TestBuildConfig() {
Name: "llama-2",
}

config, err := BuildModelConfig(modelConfig, []ocispec.Descriptor{})
s.NoError(err)

s.mockOutputStrategy.On("OutputConfig", mock.Anything, modelspec.MediaTypeModelConfig, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(ocispec.Descriptor{}, errors.New("output error")).Once()

_, err := s.builder.BuildConfig(context.Background(), []ocispec.Descriptor{}, modelConfig, hooks.NewHooks())
_, err = s.builder.BuildConfig(context.Background(), config, hooks.NewHooks())
s.Error(err)
s.True(strings.Contains(err.Error(), "output error"))
})
Expand Down Expand Up @@ -248,7 +254,7 @@ func (s *BuilderTestSuite) TestBuildModelConfig() {
Name: "llama-2",
}

model, err := buildModelConfig(modelConfig, []ocispec.Descriptor{
model, err := BuildModelConfig(modelConfig, []ocispec.Descriptor{
{Digest: godigest.Digest("sha256:layer-1")},
{Digest: godigest.Digest("sha256:layer-2")},
})
Expand Down
2 changes: 2 additions & 0 deletions pkg/config/attach.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Attach struct {
Nydusify bool
Force bool
Raw bool
Config bool
}

func NewAttach() *Attach {
Expand All @@ -39,6 +40,7 @@ func NewAttach() *Attach {
Nydusify: false,
Force: false,
Raw: false,
Config: false,
}
}

Expand Down
Loading