From f10a5da6f9365ca0cf19fe4375420c1a47c4e439 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Tue, 30 Sep 2025 11:18:12 +0200 Subject: [PATCH 01/19] feat: add support for safetensors model format and related functionality --- pkg/distribution/builder/builder.go | 34 +++++ pkg/distribution/internal/bundle/bundle.go | 18 +++ pkg/distribution/internal/bundle/unpack.go | 138 +++++++++++++++++- pkg/distribution/internal/partial/partial.go | 19 +++ .../internal/safetensors/create.go | 118 +++++++++++++++ .../internal/safetensors/model.go | 99 +++++++++++++ .../internal/safetensors/model_test.go | 79 ++++++++++ pkg/distribution/internal/store/model.go | 8 + pkg/distribution/types/config.go | 10 +- pkg/distribution/types/model.go | 4 + .../backends/llamacpp/llamacpp_config_test.go | 8 + 11 files changed, 528 insertions(+), 7 deletions(-) create mode 100644 pkg/distribution/internal/safetensors/create.go create mode 100644 pkg/distribution/internal/safetensors/model.go create mode 100644 pkg/distribution/internal/safetensors/model_test.go diff --git a/pkg/distribution/builder/builder.go b/pkg/distribution/builder/builder.go index ca8c284e8..b5773b73d 100644 --- a/pkg/distribution/builder/builder.go +++ b/pkg/distribution/builder/builder.go @@ -8,6 +8,7 @@ import ( "github.com/docker/model-runner/pkg/distribution/internal/gguf" "github.com/docker/model-runner/pkg/distribution/internal/mutate" "github.com/docker/model-runner/pkg/distribution/internal/partial" + "github.com/docker/model-runner/pkg/distribution/internal/safetensors" "github.com/docker/model-runner/pkg/distribution/types" ) @@ -27,6 +28,28 @@ func FromGGUF(path string) (*Builder, error) { }, nil } +// FromSafetensors returns a *Builder that builds model artifacts from safetensors files +func FromSafetensors(paths []string) (*Builder, error) { + mdl, err := safetensors.NewModel(paths) + if err != nil { + return nil, err + } + return &Builder{ + model: mdl, + }, nil +} + +// FromSafetensorsWithConfig returns a *Builder that builds model artifacts from safetensors files with a config archive +func FromSafetensorsWithConfig(safetensorsPaths []string, configArchivePath string) (*Builder, error) { + mdl, err := safetensors.NewModelWithConfigArchive(safetensorsPaths, configArchivePath) + if err != nil { + return nil, err + } + return &Builder{ + model: mdl, + }, nil +} + // WithLicense adds a license file to the artifact func (b *Builder) WithLicense(path string) (*Builder, error) { licenseLayer, err := partial.NewLayer(path, types.MediaTypeLicense) @@ -66,6 +89,17 @@ func (b *Builder) WithChatTemplateFile(path string) (*Builder, error) { }, nil } +// WithConfigArchive adds a config archive (tar) file to the artifact +func (b *Builder) WithConfigArchive(path string) (*Builder, error) { + configLayer, err := partial.NewLayer(path, types.MediaTypeVLLMConfigArchive) + if err != nil { + return nil, fmt.Errorf("config archive layer from %q: %w", path, err) + } + return &Builder{ + model: mutate.AppendLayers(b.model, configLayer), + }, nil +} + // Target represents a build target type Target interface { Write(context.Context, types.ModelArtifact, io.Writer) error diff --git a/pkg/distribution/internal/bundle/bundle.go b/pkg/distribution/internal/bundle/bundle.go index 5984757d2..d117d63ca 100644 --- a/pkg/distribution/internal/bundle/bundle.go +++ b/pkg/distribution/internal/bundle/bundle.go @@ -11,6 +11,8 @@ type Bundle struct { dir string mmprojPath string ggufFile string // path to GGUF file (first shard when model is split among files) + safetensorsFile string // path to safetensors file (first shard when model is split among files) + configDir string // path to extracted config directory runtimeConfig types.Config chatTemplatePath string } @@ -45,6 +47,22 @@ func (b *Bundle) ChatTemplatePath() string { return filepath.Join(b.dir, b.chatTemplatePath) } +// SafetensorsPath returns the path to model safetensors file. If the model is sharded this will be the path to the first shard. +func (b *Bundle) SafetensorsPath() string { + if b.safetensorsFile == "" { + return "" + } + return filepath.Join(b.dir, b.safetensorsFile) +} + +// ConfigDir returns the path to the extracted config directory or "" if none is present. +func (b *Bundle) ConfigDir() string { + if b.configDir == "" { + return "" + } + return filepath.Join(b.dir, b.configDir) +} + // RuntimeConfig returns config that should be respected by the backend at runtime. func (b *Bundle) RuntimeConfig() types.Config { return b.runtimeConfig diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 90cb19541..34836b6f1 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -7,6 +7,7 @@ import ( "path/filepath" "github.com/docker/model-runner/pkg/distribution/types" + ggcrtypes "github.com/google/go-containerregistry/pkg/v1/types" ) // Unpack creates and return a Bundle by unpacking files and config from model into dir. @@ -14,21 +15,86 @@ func Unpack(dir string, model types.Model) (*Bundle, error) { bundle := &Bundle{ dir: dir, } - if err := unpackGGUFs(bundle, model); err != nil { - return nil, fmt.Errorf("add GGUF file(s) to runtime bundle: %w", err) + + // Inspect layers to determine what to unpack + modelFormat := detectModelFormat(model) + + // Unpack model weights based on detected format + switch modelFormat { + case types.FormatGGUF: + if err := unpackGGUFs(bundle, model); err != nil { + return nil, fmt.Errorf("unpack GGUF files: %w", err) + } + case types.FormatSafetensors: + if err := unpackSafetensors(bundle, model); err != nil { + return nil, fmt.Errorf("unpack safetensors files: %w", err) + } + default: + return nil, fmt.Errorf("no supported model weights found (neither GGUF nor safetensors)") } - if err := unpackMultiModalProjector(bundle, model); err != nil { - return nil, fmt.Errorf("add multi-model projector file to runtime bundle: %w", err) + + // Unpack optional components based on their presence + if hasLayerWithMediaType(model, types.MediaTypeMultimodalProjector) { + if err := unpackMultiModalProjector(bundle, model); err != nil { + return nil, fmt.Errorf("add multi-model projector file to runtime bundle: %w", err) + } } - if err := unpackTemplate(bundle, model); err != nil { - return nil, fmt.Errorf("add chat template file to runtime bundle: %w", err) + + if hasLayerWithMediaType(model, types.MediaTypeChatTemplate) { + if err := unpackTemplate(bundle, model); err != nil { + return nil, fmt.Errorf("add chat template file to runtime bundle: %w", err) + } + } + + if hasLayerWithMediaType(model, types.MediaTypeVLLMConfigArchive) { + if err := unpackConfigArchive(bundle, model); err != nil { + return nil, fmt.Errorf("add config archive to runtime bundle: %w", err) + } } + + // Always create the runtime config if err := unpackRuntimeConfig(bundle, model); err != nil { return nil, fmt.Errorf("add config.json to runtime bundle: %w", err) } + return bundle, nil } +// detectModelFormat inspects the model to determine the primary model format +func detectModelFormat(model types.Model) types.Format { + // Check for GGUF files + ggufPaths, err := model.GGUFPaths() + if err == nil && len(ggufPaths) > 0 { + return types.FormatGGUF + } + + // Check for Safetensors files + safetensorsPaths, err := model.SafetensorsPaths() + if err == nil && len(safetensorsPaths) > 0 { + return types.FormatSafetensors + } + + return "" +} + +// hasLayerWithMediaType checks if the model contains a layer with the specified media type +func hasLayerWithMediaType(model types.Model, targetMediaType ggcrtypes.MediaType) bool { + // Check specific media types using the model's methods + switch targetMediaType { + case types.MediaTypeMultimodalProjector: + path, err := model.MMPROJPath() + return err == nil && path != "" + case types.MediaTypeChatTemplate: + path, err := model.ChatTemplatePath() + return err == nil && path != "" + case types.MediaTypeVLLMConfigArchive: + path, err := model.ConfigArchivePath() + return err == nil && path != "" + default: + return false + } +} + func unpackRuntimeConfig(bundle *Bundle, mdl types.Model) error { cfg, err := mdl.Config() if err != nil { @@ -95,6 +161,66 @@ func unpackTemplate(bundle *Bundle, mdl types.Model) error { return nil } +func unpackSafetensors(bundle *Bundle, mdl types.Model) error { + safetensorsPaths, err := mdl.SafetensorsPaths() + if err != nil { + return fmt.Errorf("get safetensors files for model: %w", err) + } + + if len(safetensorsPaths) == 0 { + return fmt.Errorf("no safetensors files found") + } + + if len(safetensorsPaths) == 1 { + if err := unpackFile(filepath.Join(bundle.dir, "model.safetensors"), safetensorsPaths[0]); err != nil { + return err + } + bundle.safetensorsFile = "model.safetensors" + return nil + } + + // Handle sharded safetensors files + for i := range safetensorsPaths { + name := fmt.Sprintf("model-%05d-of-%05d.safetensors", i+1, len(safetensorsPaths)) + if err := unpackFile(filepath.Join(bundle.dir, name), safetensorsPaths[i]); err != nil { + return err + } + if i == 0 { + bundle.safetensorsFile = name + } + } + + return nil +} + +func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { + archivePath, err := mdl.ConfigArchivePath() + if err != nil { + return nil // no config archive + } + + // Create config directory + configDir := filepath.Join(bundle.dir, "configs") + if err := os.MkdirAll(configDir, 0755); err != nil { + return fmt.Errorf("create config directory: %w", err) + } + + // Extract the tar archive + if err := extractTarArchive(archivePath, configDir); err != nil { + return fmt.Errorf("extract config archive: %w", err) + } + + bundle.configDir = "configs" + return nil +} + +func extractTarArchive(archivePath, destDir string) error { + // For now, we'll just link the tar file. + // TODO: Implement proper tar extraction using archive/tar package + // This would extract files like tokenizer.json, config.json, etc. + return os.Link(archivePath, filepath.Join(destDir, "config.tar")) +} + func unpackFile(bundlePath string, srcPath string) error { return os.Link(srcPath, bundlePath) } diff --git a/pkg/distribution/internal/partial/partial.go b/pkg/distribution/internal/partial/partial.go index 00bedef64..9b9fe4ca3 100644 --- a/pkg/distribution/internal/partial/partial.go +++ b/pkg/distribution/internal/partial/partial.go @@ -99,6 +99,25 @@ func ChatTemplatePath(i WithLayers) (string, error) { return paths[0], err } +func SafetensorsPaths(i WithLayers) ([]string, error) { + return layerPathsByMediaType(i, types.MediaTypeSafetensors) +} + +func ConfigArchivePath(i WithLayers) (string, error) { + paths, err := layerPathsByMediaType(i, types.MediaTypeVLLMConfigArchive) + if err != nil { + return "", fmt.Errorf("get config archive layer paths: %w", err) + } + if len(paths) == 0 { + return "", fmt.Errorf("model does not contain any layer of type %q", types.MediaTypeVLLMConfigArchive) + } + if len(paths) > 1 { + return "", fmt.Errorf("found %d files of type %q, expected exactly 1", + len(paths), types.MediaTypeVLLMConfigArchive) + } + return paths[0], err +} + // layerPathsByMediaType is a generic helper function that finds a layer by media type and returns its path func layerPathsByMediaType(i WithLayers, mediaType ggcr.MediaType) ([]string, error) { layers, err := i.Layers() diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go new file mode 100644 index 000000000..96e41748a --- /dev/null +++ b/pkg/distribution/internal/safetensors/create.go @@ -0,0 +1,118 @@ +package safetensors + +import ( + "fmt" + "path/filepath" + "strings" + "time" + + v1 "github.com/google/go-containerregistry/pkg/v1" + + "github.com/docker/model-runner/pkg/distribution/internal/partial" + "github.com/docker/model-runner/pkg/distribution/types" +) + +// NewModel creates a new safetensors model from one or more safetensors files +func NewModel(paths []string) (*Model, error) { + if len(paths) == 0 { + return nil, fmt.Errorf("at least one safetensors file is required") + } + + layers := make([]v1.Layer, len(paths)) + diffIDs := make([]v1.Hash, len(paths)) + + for i, path := range paths { + layer, err := partial.NewLayer(path, types.MediaTypeSafetensors) + if err != nil { + return nil, fmt.Errorf("create safetensors layer from %q: %w", path, err) + } + diffID, err := layer.DiffID() + if err != nil { + return nil, fmt.Errorf("get safetensors layer diffID: %w", err) + } + layers[i] = layer + diffIDs[i] = diffID + } + + created := time.Now() + return &Model{ + configFile: types.ConfigFile{ + Config: configFromFiles(paths), + Descriptor: types.Descriptor{ + Created: &created, + }, + RootFS: v1.RootFS{ + Type: "rootfs", + DiffIDs: diffIDs, + }, + }, + layers: layers, + }, nil +} + +// NewModelWithConfigArchive creates a new safetensors model with a config archive +func NewModelWithConfigArchive(safetensorsPaths []string, configArchivePath string) (*Model, error) { + model, err := NewModel(safetensorsPaths) + if err != nil { + return nil, err + } + + // Add config archive layer + if configArchivePath != "" { + configLayer, err := partial.NewLayer(configArchivePath, types.MediaTypeVLLMConfigArchive) + if err != nil { + return nil, fmt.Errorf("create config archive layer from %q: %w", configArchivePath, err) + } + + diffID, err := configLayer.DiffID() + if err != nil { + return nil, fmt.Errorf("get config archive layer diffID: %w", err) + } + + model.layers = append(model.layers, configLayer) + model.configFile.RootFS.DiffIDs = append(model.configFile.RootFS.DiffIDs, diffID) + } + + return model, nil +} + +func configFromFiles(paths []string) types.Config { + // Extract basic metadata from file paths + // This is a simplified version - in production, you might want to + // parse safetensors headers for more detailed metadata + + var totalFiles int + var architecture string + + if len(paths) > 0 { + totalFiles = len(paths) + // Try to extract architecture from filename + baseName := filepath.Base(paths[0]) + baseName = strings.ToLower(baseName) + + // Common patterns in model filenames + if strings.Contains(baseName, "llama") { + architecture = "llama" + } else if strings.Contains(baseName, "mistral") { + architecture = "mistral" + } else if strings.Contains(baseName, "qwen") { + architecture = "qwen" + } else if strings.Contains(baseName, "gemma") { + architecture = "gemma" + } + } + + safetensorsMetadata := map[string]string{ + "total_files": fmt.Sprintf("%d", totalFiles), + } + + if architecture != "" { + safetensorsMetadata["architecture"] = architecture + } + + return types.Config{ + Format: types.FormatSafetensors, + Architecture: architecture, + Safetensors: safetensorsMetadata, + } +} diff --git a/pkg/distribution/internal/safetensors/model.go b/pkg/distribution/internal/safetensors/model.go new file mode 100644 index 000000000..4562dfac2 --- /dev/null +++ b/pkg/distribution/internal/safetensors/model.go @@ -0,0 +1,99 @@ +package safetensors + +import ( + "encoding/json" + "fmt" + + v1 "github.com/google/go-containerregistry/pkg/v1" + "github.com/google/go-containerregistry/pkg/v1/partial" + ggcr "github.com/google/go-containerregistry/pkg/v1/types" + + mdpartial "github.com/docker/model-runner/pkg/distribution/internal/partial" + "github.com/docker/model-runner/pkg/distribution/types" +) + +var _ types.ModelArtifact = &Model{} + +type Model struct { + configFile types.ConfigFile + layers []v1.Layer + manifest *v1.Manifest +} + +func (m *Model) Layers() ([]v1.Layer, error) { + return m.layers, nil +} + +func (m *Model) Size() (int64, error) { + return partial.Size(m) +} + +func (m *Model) ConfigName() (v1.Hash, error) { + return partial.ConfigName(m) +} + +func (m *Model) ConfigFile() (*v1.ConfigFile, error) { + return nil, fmt.Errorf("invalid for model") +} + +func (m *Model) Digest() (v1.Hash, error) { + return partial.Digest(m) +} + +func (m *Model) Manifest() (*v1.Manifest, error) { + return mdpartial.ManifestForLayers(m) +} + +func (m *Model) LayerByDigest(hash v1.Hash) (v1.Layer, error) { + for _, l := range m.layers { + d, err := l.Digest() + if err != nil { + return nil, fmt.Errorf("get layer digest: %w", err) + } + if d == hash { + return l, nil + } + } + return nil, fmt.Errorf("layer not found") +} + +func (m *Model) LayerByDiffID(hash v1.Hash) (v1.Layer, error) { + for _, l := range m.layers { + d, err := l.DiffID() + if err != nil { + return nil, fmt.Errorf("get layer digest: %w", err) + } + if d == hash { + return l, nil + } + } + return nil, fmt.Errorf("layer not found") +} + +func (m *Model) RawManifest() ([]byte, error) { + return partial.RawManifest(m) +} + +func (m *Model) RawConfigFile() ([]byte, error) { + return json.Marshal(m.configFile) +} + +func (m *Model) MediaType() (ggcr.MediaType, error) { + manifest, err := m.Manifest() + if err != nil { + return "", fmt.Errorf("compute manifest: %w", err) + } + return manifest.MediaType, nil +} + +func (m *Model) ID() (string, error) { + return mdpartial.ID(m) +} + +func (m *Model) Config() (types.Config, error) { + return mdpartial.Config(m) +} + +func (m *Model) Descriptor() (types.Descriptor, error) { + return mdpartial.Descriptor(m) +} diff --git a/pkg/distribution/internal/safetensors/model_test.go b/pkg/distribution/internal/safetensors/model_test.go new file mode 100644 index 000000000..9f8dd663f --- /dev/null +++ b/pkg/distribution/internal/safetensors/model_test.go @@ -0,0 +1,79 @@ +package safetensors + +import ( + "testing" + + "github.com/docker/model-runner/pkg/distribution/types" +) + +func TestNewModel(t *testing.T) { + // Create a test safetensors model + // Note: In a real test, you would use actual safetensors files + // For now, we'll test with dummy paths to verify the structure + + t.Run("single file", func(t *testing.T) { + paths := []string{"test-model.safetensors"} + model, err := NewModel(paths) + if err == nil { + t.Error("Expected error for non-existent file, got nil") + } + // The error is expected since the file doesn't exist + // In a real test, we'd use test fixtures + _ = model + }) + + t.Run("empty paths", func(t *testing.T) { + var paths []string + _, err := NewModel(paths) + if err == nil { + t.Error("Expected error for empty paths, got nil") + } + }) + + t.Run("config extraction", func(t *testing.T) { + config := configFromFiles([]string{"llama-7b-model.safetensors"}) + if config.Format != types.FormatSafetensors { + t.Errorf("Expected format %s, got %s", types.FormatSafetensors, config.Format) + } + if config.Architecture != "llama" { + t.Errorf("Expected architecture 'llama', got %s", config.Architecture) + } + if config.Safetensors["total_files"] != "1" { + t.Errorf("Expected total_files '1', got %s", config.Safetensors["total_files"]) + } + }) + + t.Run("architecture detection", func(t *testing.T) { + tests := []struct { + filename string + expected string + }{ + {"mistral-7b-instruct.safetensors", "mistral"}, + {"qwen2-vl-7b.safetensors", "qwen"}, + {"gemma-2b.safetensors", "gemma"}, + {"unknown-model.safetensors", ""}, + } + + for _, tt := range tests { + config := configFromFiles([]string{tt.filename}) + if config.Architecture != tt.expected { + t.Errorf("For file %s, expected architecture %q, got %q", + tt.filename, tt.expected, config.Architecture) + } + } + }) +} + +func TestNewModelWithConfigArchive(t *testing.T) { + // Test that the function properly handles config archives + // In a real test, we'd use actual files + + safetensorsPaths := []string{"model.safetensors"} + configPath := "config.tar" + + _, err := NewModelWithConfigArchive(safetensorsPaths, configPath) + if err == nil { + t.Error("Expected error for non-existent files, got nil") + } + // The error is expected since the files don't exist +} diff --git a/pkg/distribution/internal/store/model.go b/pkg/distribution/internal/store/model.go index 131d7908f..a063e5cf2 100644 --- a/pkg/distribution/internal/store/model.go +++ b/pkg/distribution/internal/store/model.go @@ -130,6 +130,14 @@ func (m *Model) ChatTemplatePath() (string, error) { return mdpartial.ChatTemplatePath(m) } +func (m *Model) SafetensorsPaths() ([]string, error) { + return mdpartial.SafetensorsPaths(m) +} + +func (m *Model) ConfigArchivePath() (string, error) { + return mdpartial.ConfigArchivePath(m) +} + func (m *Model) Tags() []string { return m.tags } diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index 0261a9f96..adce5d9f8 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -17,6 +17,12 @@ const ( // MediaTypeGGUF indicates a file in GGUF version 3 format, containing a tensor model. MediaTypeGGUF = types.MediaType("application/vnd.docker.ai.gguf.v3") + // MediaTypeSafetensors indicates a file in safetensors format, containing model weights. + MediaTypeSafetensors = types.MediaType("application/vnd.docker.ai.safetensors") + + // MediaTypeVLLMConfigArchive indicates a tar archive containing vLLM-specific config files. + MediaTypeVLLMConfigArchive = types.MediaType("application/vnd.docker.ai.vllm.config.tar") + // MediaTypeLicense indicates a plain text file containing a license MediaTypeLicense = types.MediaType("application/vnd.docker.ai.license") @@ -26,7 +32,8 @@ const ( // MediaTypeChatTemplate indicates a Jinja chat template MediaTypeChatTemplate = types.MediaType("application/vnd.docker.ai.chat.template.jinja") - FormatGGUF = Format("gguf") + FormatGGUF = Format("gguf") + FormatSafetensors = Format("safetensors") ) type Format string @@ -45,6 +52,7 @@ type Config struct { Architecture string `json:"architecture,omitempty"` Size string `json:"size,omitempty"` GGUF map[string]string `json:"gguf,omitempty"` + Safetensors map[string]string `json:"safetensors,omitempty"` ContextSize *uint64 `json:"context_size,omitempty"` } diff --git a/pkg/distribution/types/model.go b/pkg/distribution/types/model.go index 7f9ba3948..4200ba619 100644 --- a/pkg/distribution/types/model.go +++ b/pkg/distribution/types/model.go @@ -7,6 +7,8 @@ import ( type Model interface { ID() (string, error) GGUFPaths() ([]string, error) + SafetensorsPaths() ([]string, error) + ConfigArchivePath() (string, error) MMPROJPath() (string, error) Config() (Config, error) Tags() []string @@ -24,6 +26,8 @@ type ModelArtifact interface { type ModelBundle interface { RootDir() string GGUFPath() string + SafetensorsPath() string + ConfigDir() string ChatTemplatePath() string MMPROJPath() string RuntimeConfig() Config diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index 83c8c84bc..cfbbe4424 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -279,6 +279,14 @@ func (f *fakeBundle) MMPROJPath() string { return "" } +func (f *fakeBundle) SafetensorsPath() string { + return "" +} + +func (f *fakeBundle) ConfigDir() string { + return "" +} + func (f *fakeBundle) RuntimeConfig() types.Config { return f.config } From 74cd834c0cc725771c81d3231df57ec08e90ccf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Tue, 30 Sep 2025 11:24:14 +0200 Subject: [PATCH 02/19] implement extraction logic --- pkg/distribution/internal/bundle/unpack.go | 87 +++++++++++++++++++++- 1 file changed, 83 insertions(+), 4 deletions(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 34836b6f1..dde9a6a71 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -1,10 +1,13 @@ package bundle import ( + "archive/tar" "encoding/json" "fmt" + "io" "os" "path/filepath" + "strings" "github.com/docker/model-runner/pkg/distribution/types" ggcrtypes "github.com/google/go-containerregistry/pkg/v1/types" @@ -215,10 +218,86 @@ func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { } func extractTarArchive(archivePath, destDir string) error { - // For now, we'll just link the tar file. - // TODO: Implement proper tar extraction using archive/tar package - // This would extract files like tokenizer.json, config.json, etc. - return os.Link(archivePath, filepath.Join(destDir, "config.tar")) + // Open the tar file + file, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("open tar archive: %w", err) + } + defer file.Close() + + // Create tar reader + tr := tar.NewReader(file) + + // Extract files + for { + header, err := tr.Next() + if err == io.EOF { + break // End of archive + } + if err != nil { + return fmt.Errorf("read tar header: %w", err) + } + + // Clean and validate the path to prevent directory traversal + target := filepath.Join(destDir, header.Name) + if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(destDir)+string(os.PathSeparator)) { + return fmt.Errorf("invalid tar entry: %s attempts to escape destination directory", header.Name) + } + + // Process based on header type + switch header.Typeflag { + case tar.TypeDir: + // Create directory + if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("create directory %s: %w", target, err) + } + + case tar.TypeReg: + // Extract regular file + if err := extractFile(tr, target, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("extract file %s: %w", target, err) + } + + case tar.TypeSymlink: + // Handle symlinks safely + linkTarget := filepath.Join(filepath.Dir(target), header.Linkname) + if !strings.HasPrefix(filepath.Clean(linkTarget), filepath.Clean(destDir)+string(os.PathSeparator)) { + // Skip symlinks that would escape the destination directory + continue + } + if err := os.Symlink(header.Linkname, target); err != nil { + return fmt.Errorf("create symlink %s: %w", target, err) + } + + default: + // Skip other types (block devices, char devices, FIFOs, etc.) + continue + } + } + + return nil +} + +// extractFile extracts a single file from the tar reader +func extractFile(tr io.Reader, target string, mode os.FileMode) error { + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return fmt.Errorf("create parent directory: %w", err) + } + + // Create the file + file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("create file: %w", err) + } + defer file.Close() + + // Copy contents + if _, err := io.Copy(file, tr); err != nil { + return fmt.Errorf("write file contents: %w", err) + } + + return nil } func unpackFile(bundlePath string, srcPath string) error { From bdb03a8d7a23da36e6fe953b03df0e84433ffb4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Tue, 30 Sep 2025 15:28:39 +0200 Subject: [PATCH 03/19] feat: enhance safetensors support with sharded model discovery and packaging --- Makefile | 1 + cmd/mdltool/main.go | 243 +++++++++++++++--- pkg/distribution/builder/builder.go | 11 - .../internal/safetensors/create.go | 62 ++++- 4 files changed, 274 insertions(+), 43 deletions(-) diff --git a/Makefile b/Makefile index 661a2f9f8..d070f9f1c 100644 --- a/Makefile +++ b/Makefile @@ -140,5 +140,6 @@ help: @echo "Model distribution tool examples:" @echo " make mdl-pull TAG=registry.example.com/models/llama:v1.0" @echo " make mdl-package SOURCE=./model.gguf TAG=registry.example.com/models/llama:v1.0 LICENSE=./license.txt" + @echo " make mdl-package SOURCE=./qwen2.5-3b-instruct TAG=registry.example.com/models/qwen:v1.0" @echo " make mdl-list" @echo " make mdl-rm TAG=registry.example.com/models/llama:v1.0" diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 9db316a2c..34480e884 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -1,9 +1,11 @@ package main import ( + "archive/tar" "context" "flag" "fmt" + "io" "os" "path/filepath" "strings" @@ -178,7 +180,12 @@ func cmdPackage(args []string) int { fs.StringVar(&chatTemplate, "chat-template", "", "Jinja chat template file") fs.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] \n\n") + fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] \n\n") + fmt.Fprintf(os.Stderr, "Examples:\n") + fmt.Fprintf(os.Stderr, " # GGUF model:\n") + fmt.Fprintf(os.Stderr, " model-distribution-tool package model.gguf --tag registry/model:tag\n\n") + fmt.Fprintf(os.Stderr, " # Safetensors model:\n") + fmt.Fprintf(os.Stderr, " model-distribution-tool package ./qwen-model-dir --tag registry/model:tag\n\n") fmt.Fprintf(os.Stderr, "Options:\n") fs.PrintDefaults() } @@ -189,32 +196,74 @@ func cmdPackage(args []string) int { } args = fs.Args() + // Get the source from positional argument if len(args) < 1 { - fmt.Fprintf(os.Stderr, "Error: missing arguments\n") + fmt.Fprintf(os.Stderr, "Error: no model file or directory specified\n") fs.Usage() return 1 } - if file == "" && tag == "" { - fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n") - fs.Usage() + + source := args[0] + var isSafetensors bool + var configArchive string // For safetensors config + var safetensorsPaths []string // For safetensors model files + + // Check if source exists + sourceInfo, err := os.Stat(source) + if os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "Error: source does not exist: %s\n", source) return 1 } - source := args[0] - ctx := context.Background() + // Handle directory-based packaging (for safetensors models) + if sourceInfo.IsDir() { + fmt.Printf("Detected directory, scanning for safetensors model...\n") + var err error + safetensorsPaths, configArchive, err = packageFromDirectory(source) + if err != nil { + fmt.Fprintf(os.Stderr, "Error scanning directory: %v\n", err) + return 1 + } - // Check if source file exists - if _, err := os.Stat(source); os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Error: source file does not exist: %s\n", source) - return 1 + isSafetensors = true + fmt.Printf("Found %d safetensors file(s)\n", len(safetensorsPaths)) + + // Clean up temp config archive when done + if configArchive != "" { + defer os.Remove(configArchive) + fmt.Printf("Created temporary config archive from directory\n") + } + } else { + // Handle single file (GGUF or safetensors) + if strings.HasSuffix(strings.ToLower(source), ".safetensors") { + isSafetensors = true + safetensorsPaths = []string{source} + fmt.Println("Detected safetensors model file") + + // Auto-discover configs from file's directory + parentDir := filepath.Dir(source) + _, configArchive, err = packageFromDirectory(parentDir) + if err == nil && configArchive != "" { + defer os.Remove(configArchive) + fmt.Printf("Auto-discovered config files from %s\n", parentDir) + } + } else if strings.HasSuffix(strings.ToLower(source), ".gguf") { + isSafetensors = false + fmt.Println("Detected GGUF model file") + } else { + fmt.Fprintf(os.Stderr, "Warning: could not determine model type for: %s\n", source) + fmt.Fprintf(os.Stderr, "Assuming GGUF format.\n") + } } - // Check if source file is a GGUF file - if !strings.HasSuffix(strings.ToLower(source), ".gguf") { - fmt.Fprintf(os.Stderr, "Warning: source file does not have .gguf extension: %s\n", source) - fmt.Fprintf(os.Stderr, "Continuing anyway, but this may cause issues.\n") + if file == "" && tag == "" { + fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n") + fs.Usage() + return 1 } + ctx := context.Background() + // Prepare registry client options registryClientOpts := []registry.ClientOption{ registry.WithUserAgent("model-distribution-tool/" + version), @@ -230,13 +279,11 @@ func cmdPackage(args []string) int { // Create registry client once with all options registryClient := registry.NewClient(registryClientOpts...) - var ( - target builder.Target - err error - ) + var target builder.Target if file != "" { target = tarball.NewFileTarget(file) } else { + var err error target, err = registryClient.NewTarget(tag) if err != nil { fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err) @@ -244,17 +291,32 @@ func cmdPackage(args []string) int { } } - // Create image with layer - builder, err := builder.FromGGUF(source) - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err) - return 1 + // Create builder based on model type + var b *builder.Builder + if isSafetensors { + if configArchive != "" { + fmt.Printf("Creating safetensors model with config archive: %s\n", configArchive) + b, err = builder.FromSafetensorsWithConfig(safetensorsPaths, configArchive) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating model from safetensors with config: %v\n", err) + return 1 + } + } else { + fmt.Fprintf(os.Stderr, "Error: config archive is required for safetensors models\n") + return 1 + } + } else { + b, err = builder.FromGGUF(source) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err) + return 1 + } } // Add all license files as layers for _, path := range licensePaths { fmt.Println("Adding license file:", path) - builder, err = builder.WithLicense(path) + b, err = b.WithLicense(path) if err != nil { fmt.Fprintf(os.Stderr, "Error adding license layer for %s: %v\n", path, err) return 1 @@ -263,12 +325,12 @@ func cmdPackage(args []string) int { if contextSize > 0 { fmt.Println("Setting context size:", contextSize) - builder = builder.WithContextSize(contextSize) + b = b.WithContextSize(contextSize) } if mmproj != "" { fmt.Println("Adding multimodal projector file:", mmproj) - builder, err = builder.WithMultimodalProjector(mmproj) + b, err = b.WithMultimodalProjector(mmproj) if err != nil { fmt.Fprintf(os.Stderr, "Error adding multimodal projector layer for %s: %v\n", mmproj, err) return 1 @@ -277,7 +339,7 @@ func cmdPackage(args []string) int { if chatTemplate != "" { fmt.Println("Adding chat template file:", chatTemplate) - builder, err = builder.WithChatTemplateFile(chatTemplate) + b, err = b.WithChatTemplateFile(chatTemplate) if err != nil { fmt.Fprintf(os.Stderr, "Error adding chat template layer for %s: %v\n", chatTemplate, err) return 1 @@ -285,7 +347,7 @@ func cmdPackage(args []string) int { } // Push the image - if err := builder.Build(ctx, target, os.Stdout); err != nil { + if err := b.Build(ctx, target, os.Stdout); err != nil { fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err) return 1 } @@ -525,3 +587,126 @@ func cmdBundle(client *distribution.Client, args []string) int { fmt.Fprint(os.Stdout, bundle.RootDir()) return 0 } + +// packageFromDirectory scans a directory for safetensors files and config files, +// creating a temporary tar archive of the config files +func packageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfigArchive string, err error) { + // Read directory contents (only top level, no subdirectories) + entries, err := os.ReadDir(dirPath) + if err != nil { + return nil, "", fmt.Errorf("read directory: %w", err) + } + + var configFiles []string + + for _, entry := range entries { + if entry.IsDir() { + continue // Skip subdirectories + } + + name := entry.Name() + fullPath := filepath.Join(dirPath, name) + + // Collect safetensors files + if strings.HasSuffix(strings.ToLower(name), ".safetensors") { + safetensorsPaths = append(safetensorsPaths, fullPath) + } + + // Collect config files: *.json, merges.txt + if strings.HasSuffix(strings.ToLower(name), ".json") || + name == "merges.txt" { + configFiles = append(configFiles, fullPath) + } + } + + if len(safetensorsPaths) == 0 { + return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath) + } + + // Create temporary tar archive with config files if any exist + if len(configFiles) > 0 { + tempConfigArchive, err = createTempConfigArchive(configFiles) + if err != nil { + return nil, "", fmt.Errorf("create config archive: %w", err) + } + } + + return safetensorsPaths, tempConfigArchive, nil +} + +// createTempConfigArchive creates a temporary tar archive containing the specified config files +func createTempConfigArchive(configFiles []string) (string, error) { + // Create temp file + tmpFile, err := os.CreateTemp("", "vllm-config-*.tar") + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + // Create tar writer + tw := tar.NewWriter(tmpFile) + + // Add each config file to tar (preserving just filename, not full path) + for _, filePath := range configFiles { + // Open the file + file, err := os.Open(filePath) + if err != nil { + tw.Close() + tmpFile.Close() + os.Remove(tmpPath) + return "", fmt.Errorf("open config file %s: %w", filePath, err) + } + + // Get file info for tar header + fileInfo, err := file.Stat() + if err != nil { + file.Close() + tw.Close() + tmpFile.Close() + os.Remove(tmpPath) + return "", fmt.Errorf("stat config file %s: %w", filePath, err) + } + + // Create tar header (use only basename, not full path) + header := &tar.Header{ + Name: filepath.Base(filePath), + Size: fileInfo.Size(), + Mode: int64(fileInfo.Mode()), + ModTime: fileInfo.ModTime(), + } + + // Write header + if err := tw.WriteHeader(header); err != nil { + file.Close() + tw.Close() + tmpFile.Close() + os.Remove(tmpPath) + return "", fmt.Errorf("write tar header for %s: %w", filePath, err) + } + + // Copy file contents + if _, err := io.Copy(tw, file); err != nil { + file.Close() + tw.Close() + tmpFile.Close() + os.Remove(tmpPath) + return "", fmt.Errorf("write tar content for %s: %w", filePath, err) + } + + file.Close() + } + + // Close tar writer and file + if err := tw.Close(); err != nil { + tmpFile.Close() + os.Remove(tmpPath) + return "", fmt.Errorf("close tar writer: %w", err) + } + + if err := tmpFile.Close(); err != nil { + os.Remove(tmpPath) + return "", fmt.Errorf("close temp file: %w", err) + } + + return tmpPath, nil +} diff --git a/pkg/distribution/builder/builder.go b/pkg/distribution/builder/builder.go index b5773b73d..f202dcd23 100644 --- a/pkg/distribution/builder/builder.go +++ b/pkg/distribution/builder/builder.go @@ -28,17 +28,6 @@ func FromGGUF(path string) (*Builder, error) { }, nil } -// FromSafetensors returns a *Builder that builds model artifacts from safetensors files -func FromSafetensors(paths []string) (*Builder, error) { - mdl, err := safetensors.NewModel(paths) - if err != nil { - return nil, err - } - return &Builder{ - model: mdl, - }, nil -} - // FromSafetensorsWithConfig returns a *Builder that builds model artifacts from safetensors files with a config archive func FromSafetensorsWithConfig(safetensorsPaths []string, configArchivePath string) (*Builder, error) { mdl, err := safetensors.NewModelWithConfigArchive(safetensorsPaths, configArchivePath) diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 96e41748a..615abb8d9 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -2,7 +2,10 @@ package safetensors import ( "fmt" + "os" "path/filepath" + "regexp" + "strconv" "strings" "time" @@ -13,15 +16,24 @@ import ( ) // NewModel creates a new safetensors model from one or more safetensors files +// If a sharded model pattern is detected (e.g., model-00001-of-00002.safetensors), +// it will auto-discover all related shards func NewModel(paths []string) (*Model, error) { if len(paths) == 0 { return nil, fmt.Errorf("at least one safetensors file is required") } - layers := make([]v1.Layer, len(paths)) - diffIDs := make([]v1.Hash, len(paths)) + // Auto-discover shards if the first path matches the shard pattern + allPaths := discoverSafetensorsShards(paths[0]) + if len(allPaths) == 0 { + // No shards found, use provided paths as-is + allPaths = paths + } + + layers := make([]v1.Layer, len(allPaths)) + diffIDs := make([]v1.Hash, len(allPaths)) - for i, path := range paths { + for i, path := range allPaths { layer, err := partial.NewLayer(path, types.MediaTypeSafetensors) if err != nil { return nil, fmt.Errorf("create safetensors layer from %q: %w", path, err) @@ -76,6 +88,50 @@ func NewModelWithConfigArchive(safetensorsPaths []string, configArchivePath stri return model, nil } +// discoverSafetensorsShards attempts to auto-discover all shards for a given safetensors file +// It looks for the pattern: -XXXXX-of-YYYYY.safetensors +// Returns an empty slice if no shards are found or if it's a single file +func discoverSafetensorsShards(path string) []string { + // Pattern: model-00001-of-00003.safetensors + pattern := regexp.MustCompile(`^(.+)-(\d{5})-of-(\d{5})\.safetensors$`) + + baseName := filepath.Base(path) + matches := pattern.FindStringSubmatch(baseName) + + if len(matches) != 4 { + // Not a sharded file, return empty to indicate single file + return nil + } + + prefix := matches[1] + totalShards, err := strconv.Atoi(matches[3]) + if err != nil { + return nil + } + + dir := filepath.Dir(path) + var shards []string + + // Look for all shards in the same directory + for i := 1; i <= totalShards; i++ { + shardName := fmt.Sprintf("%s-%05d-of-%05d.safetensors", prefix, i, totalShards) + shardPath := filepath.Join(dir, shardName) + + // Check if the file exists + if _, err := os.Stat(shardPath); err == nil { + shards = append(shards, shardPath) + } + } + + // Only return if we found all expected shards + if len(shards) == totalShards { + // Shards are already in order due to sequential loop + return shards + } + + return nil +} + func configFromFiles(paths []string) types.Config { // Extract basic metadata from file paths // This is a simplified version - in production, you might want to From 50814adc018ac3e548bc6b6c022dff1cafa8c9e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Tue, 30 Sep 2025 15:32:38 +0200 Subject: [PATCH 04/19] feat: add regex pattern for safetensors shard filename matching --- pkg/distribution/internal/safetensors/create.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 615abb8d9..59bc7e765 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -15,6 +15,11 @@ import ( "github.com/docker/model-runner/pkg/distribution/types" ) +var ( + // shardPattern matches safetensors shard filenames like "model-00001-of-00003.safetensors" + shardPattern = regexp.MustCompile(`^(.+)-(\d{5})-of-(\d{5})\.safetensors$`) +) + // NewModel creates a new safetensors model from one or more safetensors files // If a sharded model pattern is detected (e.g., model-00001-of-00002.safetensors), // it will auto-discover all related shards @@ -92,11 +97,8 @@ func NewModelWithConfigArchive(safetensorsPaths []string, configArchivePath stri // It looks for the pattern: -XXXXX-of-YYYYY.safetensors // Returns an empty slice if no shards are found or if it's a single file func discoverSafetensorsShards(path string) []string { - // Pattern: model-00001-of-00003.safetensors - pattern := regexp.MustCompile(`^(.+)-(\d{5})-of-(\d{5})\.safetensors$`) - baseName := filepath.Base(path) - matches := pattern.FindStringSubmatch(baseName) + matches := shardPattern.FindStringSubmatch(baseName) if len(matches) != 4 { // Not a sharded file, return empty to indicate single file From 7a65e5bae73ab130d8312b6848d6bea144994104 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 15:07:54 +0200 Subject: [PATCH 05/19] feat: enhance security with path validation to prevent directory traversal attacks --- pkg/distribution/internal/bundle/unpack.go | 96 +++++++++++++++---- .../internal/safetensors/create.go | 3 + 2 files changed, 83 insertions(+), 16 deletions(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index dde9a6a71..0f83ba01f 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -7,7 +7,6 @@ import ( "io" "os" "path/filepath" - "strings" "github.com/docker/model-runner/pkg/distribution/types" ggcrtypes "github.com/google/go-containerregistry/pkg/v1/types" @@ -217,6 +216,40 @@ func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { return nil } +// validatePathWithinDirectory checks if targetPath is within baseDir to prevent directory traversal attacks. +// It uses filepath.IsLocal() which is available in Go 1.20+ and provides robust security against +// various directory traversal attempts including edge cases like empty paths, ".", "..", symbolic links, etc. +func validatePathWithinDirectory(baseDir, targetPath string) error { + // Get absolute path of base directory + absBaseDir, err := filepath.Abs(baseDir) + if err != nil { + return fmt.Errorf("get absolute base directory path: %w", err) + } + + // Construct the target path within base directory + target := filepath.Join(absBaseDir, targetPath) + + // Get absolute path of target + absTarget, err := filepath.Abs(target) + if err != nil { + return fmt.Errorf("get absolute target path: %w", err) + } + + // Get relative path from base to target + rel, err := filepath.Rel(absBaseDir, absTarget) + if err != nil { + return fmt.Errorf("compute relative path: %w", err) + } + + // Use filepath.IsLocal() to check if the relative path is local (doesn't escape baseDir) + // This handles all edge cases including empty strings, ".", "..", symlinks, etc. + if !filepath.IsLocal(rel) { + return fmt.Errorf("invalid entry %q: path attempts to escape destination directory", targetPath) + } + + return nil +} + func extractTarArchive(archivePath, destDir string) error { // Open the tar file file, err := os.Open(archivePath) @@ -225,6 +258,12 @@ func extractTarArchive(archivePath, destDir string) error { } defer file.Close() + // Get absolute path of destination directory for security checks + absDestDir, err := filepath.Abs(destDir) + if err != nil { + return fmt.Errorf("get absolute destination path: %w", err) + } + // Create tar reader tr := tar.NewReader(file) @@ -238,35 +277,60 @@ func extractTarArchive(archivePath, destDir string) error { return fmt.Errorf("read tar header: %w", err) } - // Clean and validate the path to prevent directory traversal - target := filepath.Join(destDir, header.Name) - if !strings.HasPrefix(filepath.Clean(target), filepath.Clean(destDir)+string(os.PathSeparator)) { - return fmt.Errorf("invalid tar entry: %s attempts to escape destination directory", header.Name) + // Validate the target path to prevent directory traversal + if err := validatePathWithinDirectory(absDestDir, header.Name); err != nil { + return err } + // Construct the validated target path + absTarget := filepath.Join(absDestDir, header.Name) + // Process based on header type switch header.Typeflag { case tar.TypeDir: // Create directory - if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("create directory %s: %w", target, err) + if err := os.MkdirAll(absTarget, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("create directory %s: %w", absTarget, err) } case tar.TypeReg: // Extract regular file - if err := extractFile(tr, target, os.FileMode(header.Mode)); err != nil { - return fmt.Errorf("extract file %s: %w", target, err) + if err := extractFile(tr, absTarget, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("extract file %s: %w", absTarget, err) } case tar.TypeSymlink: - // Handle symlinks safely - linkTarget := filepath.Join(filepath.Dir(target), header.Linkname) - if !strings.HasPrefix(filepath.Clean(linkTarget), filepath.Clean(destDir)+string(os.PathSeparator)) { - // Skip symlinks that would escape the destination directory - continue + // Handle symlinks safely - validate where the symlink will actually point after resolution. + // Symlinks are resolved relative to their parent directory, not the base directory. + // We must validate the final resolved absolute path to prevent directory traversal. + + // Calculate the symlink's parent directory (where it will be created) + symlinkParent := filepath.Dir(absTarget) + + // Resolve the symlink target relative to the symlink's parent directory + // This gives us where the symlink will actually point when followed + resolvedTarget := filepath.Join(symlinkParent, header.Linkname) + + // Get the absolute path of where the symlink will point + absResolvedTarget, err := filepath.Abs(resolvedTarget) + if err != nil { + return fmt.Errorf("resolve symlink target for %q: %w", header.Name, err) + } + + // Validate that the resolved absolute path stays within the destination directory + rel, err := filepath.Rel(absDestDir, absResolvedTarget) + if err != nil { + return fmt.Errorf("validate symlink target for %q: %w", header.Name, err) } - if err := os.Symlink(header.Linkname, target); err != nil { - return fmt.Errorf("create symlink %s: %w", target, err) + + // Use filepath.IsLocal() to ensure the symlink target doesn't escape the base directory + if !filepath.IsLocal(rel) { + return fmt.Errorf("invalid symlink %q: target %q", + header.Name, header.Linkname) + } + + if err := os.Symlink(header.Linkname, absTarget); err != nil { + return fmt.Errorf("create symlink %s: %w", absTarget, err) } default: diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 59bc7e765..46e19fd1b 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -17,6 +17,9 @@ import ( var ( // shardPattern matches safetensors shard filenames like "model-00001-of-00003.safetensors" + // This pattern assumes 5-digit zero-padded numbering (e.g., 00001-of-00003), which is + // the most common format used by popular model repositories. + // The pattern enforces consistent padding width for both the shard number and total count. shardPattern = regexp.MustCompile(`^(.+)-(\d{5})-of-(\d{5})\.safetensors$`) ) From 68a8f28d329c2e8960fa2103428232faba70fb1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 15:30:09 +0200 Subject: [PATCH 06/19] feat: skip symlinks in model distribution to prevent directory traversal attacks --- pkg/distribution/internal/bundle/unpack.go | 36 +++------------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 0f83ba01f..5d0533ae2 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -300,38 +300,10 @@ func extractTarArchive(archivePath, destDir string) error { } case tar.TypeSymlink: - // Handle symlinks safely - validate where the symlink will actually point after resolution. - // Symlinks are resolved relative to their parent directory, not the base directory. - // We must validate the final resolved absolute path to prevent directory traversal. - - // Calculate the symlink's parent directory (where it will be created) - symlinkParent := filepath.Dir(absTarget) - - // Resolve the symlink target relative to the symlink's parent directory - // This gives us where the symlink will actually point when followed - resolvedTarget := filepath.Join(symlinkParent, header.Linkname) - - // Get the absolute path of where the symlink will point - absResolvedTarget, err := filepath.Abs(resolvedTarget) - if err != nil { - return fmt.Errorf("resolve symlink target for %q: %w", header.Name, err) - } - - // Validate that the resolved absolute path stays within the destination directory - rel, err := filepath.Rel(absDestDir, absResolvedTarget) - if err != nil { - return fmt.Errorf("validate symlink target for %q: %w", header.Name, err) - } - - // Use filepath.IsLocal() to ensure the symlink target doesn't escape the base directory - if !filepath.IsLocal(rel) { - return fmt.Errorf("invalid symlink %q: target %q", - header.Name, header.Linkname) - } - - if err := os.Symlink(header.Linkname, absTarget); err != nil { - return fmt.Errorf("create symlink %s: %w", absTarget, err) - } + // Skip symlinks - not needed for model distribution + // Symlinks could enable directory traversal attacks even with validation + // Model archives should only contain regular files and directories + continue default: // Skip other types (block devices, char devices, FIFOs, etc.) From 6a427c992d13814bbfffd26347bb6d27353c89ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 15:45:57 +0200 Subject: [PATCH 07/19] feat: update config loading to use allPaths for model creation --- pkg/distribution/internal/safetensors/create.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 46e19fd1b..e63fb6afb 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -57,7 +57,7 @@ func NewModel(paths []string) (*Model, error) { created := time.Now() return &Model{ configFile: types.ConfigFile{ - Config: configFromFiles(paths), + Config: configFromFiles(allPaths), Descriptor: types.Descriptor{ Created: &created, }, From af3e40ed74ee97427709ebd27a77e511f7bf3bfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 15:47:11 +0200 Subject: [PATCH 08/19] feat: improve error handling for missing config archive in unpackConfigArchive --- pkg/distribution/internal/bundle/unpack.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 5d0533ae2..8f6075296 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -198,7 +198,11 @@ func unpackSafetensors(bundle *Bundle, mdl types.Model) error { func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { archivePath, err := mdl.ConfigArchivePath() if err != nil { - return nil // no config archive + // Only suppress "not found" error, propagate others + if os.IsNotExist(err) { + return nil // no config archive + } + return fmt.Errorf("get config archive path: %w", err) } // Create config directory From 68abdd92c2443a39d46aaaea098d5ea6b48e072e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 15:57:55 +0200 Subject: [PATCH 09/19] feat: prevent duplicate config archive layers during model creation --- pkg/distribution/internal/safetensors/create.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index e63fb6afb..5fe453109 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -79,6 +79,14 @@ func NewModelWithConfigArchive(safetensorsPaths []string, configArchivePath stri // Add config archive layer if configArchivePath != "" { + // Check if a config archive layer already exists + for _, layer := range model.layers { + mediaType, err := layer.MediaType() + if err == nil && mediaType == types.MediaTypeVLLMConfigArchive { + return nil, fmt.Errorf("model already has a config archive layer") + } + } + configLayer, err := partial.NewLayer(configArchivePath, types.MediaTypeVLLMConfigArchive) if err != nil { return nil, fmt.Errorf("create config archive layer from %q: %w", configArchivePath, err) From a155d9511e8d9e428c5c457553f0d6f463a992f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 16:17:24 +0200 Subject: [PATCH 10/19] feat: update packaging command in Makefile for model distribution --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index d070f9f1c..eb464354e 100644 --- a/Makefile +++ b/Makefile @@ -86,7 +86,7 @@ mdl-pull: model-distribution-tool mdl-package: model-distribution-tool @echo "Packaging model $(SOURCE) to $(TAG)..." - ./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) package $(SOURCE) --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE)) + ./$(MDL_TOOL_NAME) package --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE)) $(SOURCE) mdl-list: model-distribution-tool @echo "Listing models..." From 57175cf74cf72fb54ed0d36d8c4f5eb9ff725a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 16:25:53 +0200 Subject: [PATCH 11/19] feat: update model file handling to differentiate between GGUF and safetensors formats --- cmd/mdltool/main.go | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 34480e884..9d1a64e25 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -234,20 +234,8 @@ func cmdPackage(args []string) int { fmt.Printf("Created temporary config archive from directory\n") } } else { - // Handle single file (GGUF or safetensors) - if strings.HasSuffix(strings.ToLower(source), ".safetensors") { - isSafetensors = true - safetensorsPaths = []string{source} - fmt.Println("Detected safetensors model file") - - // Auto-discover configs from file's directory - parentDir := filepath.Dir(source) - _, configArchive, err = packageFromDirectory(parentDir) - if err == nil && configArchive != "" { - defer os.Remove(configArchive) - fmt.Printf("Auto-discovered config files from %s\n", parentDir) - } - } else if strings.HasSuffix(strings.ToLower(source), ".gguf") { + // Handle single file (GGUF model) + if strings.HasSuffix(strings.ToLower(source), ".gguf") { isSafetensors = false fmt.Println("Detected GGUF model file") } else { From ffd15fd77c36bca459b5d3c67b38a24868057150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 16:46:17 +0200 Subject: [PATCH 12/19] feat: remove config directory handling from bundle and unpack logic --- pkg/distribution/internal/bundle/bundle.go | 9 --------- pkg/distribution/internal/bundle/unpack.go | 9 +-------- pkg/distribution/types/model.go | 1 - 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/pkg/distribution/internal/bundle/bundle.go b/pkg/distribution/internal/bundle/bundle.go index d117d63ca..4f58454d1 100644 --- a/pkg/distribution/internal/bundle/bundle.go +++ b/pkg/distribution/internal/bundle/bundle.go @@ -12,7 +12,6 @@ type Bundle struct { mmprojPath string ggufFile string // path to GGUF file (first shard when model is split among files) safetensorsFile string // path to safetensors file (first shard when model is split among files) - configDir string // path to extracted config directory runtimeConfig types.Config chatTemplatePath string } @@ -55,14 +54,6 @@ func (b *Bundle) SafetensorsPath() string { return filepath.Join(b.dir, b.safetensorsFile) } -// ConfigDir returns the path to the extracted config directory or "" if none is present. -func (b *Bundle) ConfigDir() string { - if b.configDir == "" { - return "" - } - return filepath.Join(b.dir, b.configDir) -} - // RuntimeConfig returns config that should be respected by the backend at runtime. func (b *Bundle) RuntimeConfig() types.Config { return b.runtimeConfig diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 8f6075296..40bf5ab67 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -205,18 +205,11 @@ func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { return fmt.Errorf("get config archive path: %w", err) } - // Create config directory - configDir := filepath.Join(bundle.dir, "configs") - if err := os.MkdirAll(configDir, 0755); err != nil { - return fmt.Errorf("create config directory: %w", err) - } - // Extract the tar archive - if err := extractTarArchive(archivePath, configDir); err != nil { + if err := extractTarArchive(archivePath, bundle.dir); err != nil { return fmt.Errorf("extract config archive: %w", err) } - bundle.configDir = "configs" return nil } diff --git a/pkg/distribution/types/model.go b/pkg/distribution/types/model.go index 4200ba619..ca7592ceb 100644 --- a/pkg/distribution/types/model.go +++ b/pkg/distribution/types/model.go @@ -27,7 +27,6 @@ type ModelBundle interface { RootDir() string GGUFPath() string SafetensorsPath() string - ConfigDir() string ChatTemplatePath() string MMPROJPath() string RuntimeConfig() Config From 5829540aa6e3d34984401cf0f89db6b2f3c0fe26 Mon Sep 17 00:00:00 2001 From: Ignasi Date: Wed, 1 Oct 2025 16:56:45 +0200 Subject: [PATCH 13/19] Update pkg/distribution/internal/bundle/unpack.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/distribution/internal/bundle/unpack.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 40bf5ab67..0f2ff6ccd 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -198,10 +198,6 @@ func unpackSafetensors(bundle *Bundle, mdl types.Model) error { func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { archivePath, err := mdl.ConfigArchivePath() if err != nil { - // Only suppress "not found" error, propagate others - if os.IsNotExist(err) { - return nil // no config archive - } return fmt.Errorf("get config archive path: %w", err) } From 902cb3c2ba40c9a119ce670408e234f4f4d62a38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 16:51:58 +0200 Subject: [PATCH 14/19] feat: ensure reproducibility by sorting safetensors and config files before archiving --- cmd/mdltool/main.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 9d1a64e25..7567366e0 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -8,6 +8,7 @@ import ( "io" "os" "path/filepath" + "sort" "strings" "github.com/docker/model-runner/pkg/distribution/builder" @@ -611,8 +612,14 @@ func packageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfig return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath) } + // Sort to ensure reproducible artifacts + sort.Strings(safetensorsPaths) + // Create temporary tar archive with config files if any exist if len(configFiles) > 0 { + // Sort config files for reproducible tar archive + sort.Strings(configFiles) + tempConfigArchive, err = createTempConfigArchive(configFiles) if err != nil { return nil, "", fmt.Errorf("create config archive: %w", err) From 703344e02409106031f8be4dd90b22e3a5305b57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 16:56:20 +0200 Subject: [PATCH 15/19] feat: simplify safetensors model creation by removing config archive dependency --- cmd/mdltool/main.go | 17 +++++++++++------ pkg/distribution/builder/builder.go | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 7567366e0..490551a52 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -283,16 +283,21 @@ func cmdPackage(args []string) int { // Create builder based on model type var b *builder.Builder if isSafetensors { + fmt.Println("Creating safetensors model") + b, err = builder.FromSafetensors(safetensorsPaths) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating model from safetensors: %v\n", err) + return 1 + } + + // Add config archive if provided if configArchive != "" { - fmt.Printf("Creating safetensors model with config archive: %s\n", configArchive) - b, err = builder.FromSafetensorsWithConfig(safetensorsPaths, configArchive) + fmt.Printf("Adding config archive: %s\n", configArchive) + b, err = b.WithConfigArchive(configArchive) if err != nil { - fmt.Fprintf(os.Stderr, "Error creating model from safetensors with config: %v\n", err) + fmt.Fprintf(os.Stderr, "Error adding config archive: %v\n", err) return 1 } - } else { - fmt.Fprintf(os.Stderr, "Error: config archive is required for safetensors models\n") - return 1 } } else { b, err = builder.FromGGUF(source) diff --git a/pkg/distribution/builder/builder.go b/pkg/distribution/builder/builder.go index f202dcd23..9b6be8c5f 100644 --- a/pkg/distribution/builder/builder.go +++ b/pkg/distribution/builder/builder.go @@ -28,9 +28,9 @@ func FromGGUF(path string) (*Builder, error) { }, nil } -// FromSafetensorsWithConfig returns a *Builder that builds model artifacts from safetensors files with a config archive -func FromSafetensorsWithConfig(safetensorsPaths []string, configArchivePath string) (*Builder, error) { - mdl, err := safetensors.NewModelWithConfigArchive(safetensorsPaths, configArchivePath) +// FromSafetensors returns a *Builder that builds model artifacts from safetensors files +func FromSafetensors(safetensorsPaths []string) (*Builder, error) { + mdl, err := safetensors.NewModel(safetensorsPaths) if err != nil { return nil, err } @@ -80,6 +80,19 @@ func (b *Builder) WithChatTemplateFile(path string) (*Builder, error) { // WithConfigArchive adds a config archive (tar) file to the artifact func (b *Builder) WithConfigArchive(path string) (*Builder, error) { + // Check if config archive already exists + layers, err := b.model.Layers() + if err != nil { + return nil, fmt.Errorf("get model layers: %w", err) + } + + for _, layer := range layers { + mediaType, err := layer.MediaType() + if err == nil && mediaType == types.MediaTypeVLLMConfigArchive { + return nil, fmt.Errorf("model already has a config archive layer") + } + } + configLayer, err := partial.NewLayer(path, types.MediaTypeVLLMConfigArchive) if err != nil { return nil, fmt.Errorf("config archive layer from %q: %w", path, err) From 9ad64324a39cab68d2d2634c5cef2c387434122d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 17:01:22 +0200 Subject: [PATCH 16/19] simplify --- .../internal/safetensors/create.go | 42 +--------- .../internal/safetensors/model_test.go | 79 ------------------- 2 files changed, 3 insertions(+), 118 deletions(-) delete mode 100644 pkg/distribution/internal/safetensors/model_test.go diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 5fe453109..85b117713 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -6,7 +6,6 @@ import ( "path/filepath" "regexp" "strconv" - "strings" "time" v1 "github.com/google/go-containerregistry/pkg/v1" @@ -57,7 +56,7 @@ func NewModel(paths []string) (*Model, error) { created := time.Now() return &Model{ configFile: types.ConfigFile{ - Config: configFromFiles(allPaths), + Config: configFromFiles(), Descriptor: types.Descriptor{ Created: &created, }, @@ -145,43 +144,8 @@ func discoverSafetensorsShards(path string) []string { return nil } -func configFromFiles(paths []string) types.Config { - // Extract basic metadata from file paths - // This is a simplified version - in production, you might want to - // parse safetensors headers for more detailed metadata - - var totalFiles int - var architecture string - - if len(paths) > 0 { - totalFiles = len(paths) - // Try to extract architecture from filename - baseName := filepath.Base(paths[0]) - baseName = strings.ToLower(baseName) - - // Common patterns in model filenames - if strings.Contains(baseName, "llama") { - architecture = "llama" - } else if strings.Contains(baseName, "mistral") { - architecture = "mistral" - } else if strings.Contains(baseName, "qwen") { - architecture = "qwen" - } else if strings.Contains(baseName, "gemma") { - architecture = "gemma" - } - } - - safetensorsMetadata := map[string]string{ - "total_files": fmt.Sprintf("%d", totalFiles), - } - - if architecture != "" { - safetensorsMetadata["architecture"] = architecture - } - +func configFromFiles() types.Config { return types.Config{ - Format: types.FormatSafetensors, - Architecture: architecture, - Safetensors: safetensorsMetadata, + Format: types.FormatSafetensors, } } diff --git a/pkg/distribution/internal/safetensors/model_test.go b/pkg/distribution/internal/safetensors/model_test.go deleted file mode 100644 index 9f8dd663f..000000000 --- a/pkg/distribution/internal/safetensors/model_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package safetensors - -import ( - "testing" - - "github.com/docker/model-runner/pkg/distribution/types" -) - -func TestNewModel(t *testing.T) { - // Create a test safetensors model - // Note: In a real test, you would use actual safetensors files - // For now, we'll test with dummy paths to verify the structure - - t.Run("single file", func(t *testing.T) { - paths := []string{"test-model.safetensors"} - model, err := NewModel(paths) - if err == nil { - t.Error("Expected error for non-existent file, got nil") - } - // The error is expected since the file doesn't exist - // In a real test, we'd use test fixtures - _ = model - }) - - t.Run("empty paths", func(t *testing.T) { - var paths []string - _, err := NewModel(paths) - if err == nil { - t.Error("Expected error for empty paths, got nil") - } - }) - - t.Run("config extraction", func(t *testing.T) { - config := configFromFiles([]string{"llama-7b-model.safetensors"}) - if config.Format != types.FormatSafetensors { - t.Errorf("Expected format %s, got %s", types.FormatSafetensors, config.Format) - } - if config.Architecture != "llama" { - t.Errorf("Expected architecture 'llama', got %s", config.Architecture) - } - if config.Safetensors["total_files"] != "1" { - t.Errorf("Expected total_files '1', got %s", config.Safetensors["total_files"]) - } - }) - - t.Run("architecture detection", func(t *testing.T) { - tests := []struct { - filename string - expected string - }{ - {"mistral-7b-instruct.safetensors", "mistral"}, - {"qwen2-vl-7b.safetensors", "qwen"}, - {"gemma-2b.safetensors", "gemma"}, - {"unknown-model.safetensors", ""}, - } - - for _, tt := range tests { - config := configFromFiles([]string{tt.filename}) - if config.Architecture != tt.expected { - t.Errorf("For file %s, expected architecture %q, got %q", - tt.filename, tt.expected, config.Architecture) - } - } - }) -} - -func TestNewModelWithConfigArchive(t *testing.T) { - // Test that the function properly handles config archives - // In a real test, we'd use actual files - - safetensorsPaths := []string{"model.safetensors"} - configPath := "config.tar" - - _, err := NewModelWithConfigArchive(safetensorsPaths, configPath) - if err == nil { - t.Error("Expected error for non-existent files, got nil") - } - // The error is expected since the files don't exist -} From 2db1646aa89c684d494017a732a04d1f28837a65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 17:27:23 +0200 Subject: [PATCH 17/19] feat: enhance shard discovery by adding error handling for incomplete sets --- .../internal/safetensors/create.go | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/pkg/distribution/internal/safetensors/create.go b/pkg/distribution/internal/safetensors/create.go index 85b117713..f641738c5 100644 --- a/pkg/distribution/internal/safetensors/create.go +++ b/pkg/distribution/internal/safetensors/create.go @@ -31,9 +31,12 @@ func NewModel(paths []string) (*Model, error) { } // Auto-discover shards if the first path matches the shard pattern - allPaths := discoverSafetensorsShards(paths[0]) + allPaths, err := discoverSafetensorsShards(paths[0]) + if err != nil { + return nil, fmt.Errorf("discover safetensors shards: %w", err) + } if len(allPaths) == 0 { - // No shards found, use provided paths as-is + // Not a sharded file, use provided paths as-is allPaths = paths } @@ -105,20 +108,21 @@ func NewModelWithConfigArchive(safetensorsPaths []string, configArchivePath stri // discoverSafetensorsShards attempts to auto-discover all shards for a given safetensors file // It looks for the pattern: -XXXXX-of-YYYYY.safetensors -// Returns an empty slice if no shards are found or if it's a single file -func discoverSafetensorsShards(path string) []string { +// Returns (nil, nil) for single-file models, (paths, nil) for complete shard sets, +// or (nil, error) for incomplete shard sets +func discoverSafetensorsShards(path string) ([]string, error) { baseName := filepath.Base(path) matches := shardPattern.FindStringSubmatch(baseName) if len(matches) != 4 { - // Not a sharded file, return empty to indicate single file - return nil + // Not a sharded file, return empty slice with no error + return nil, nil } prefix := matches[1] totalShards, err := strconv.Atoi(matches[3]) if err != nil { - return nil + return nil, fmt.Errorf("parse shard count: %w", err) } dir := filepath.Dir(path) @@ -135,13 +139,13 @@ func discoverSafetensorsShards(path string) []string { } } - // Only return if we found all expected shards - if len(shards) == totalShards { - // Shards are already in order due to sequential loop - return shards + // Return error if we didn't find all expected shards + if len(shards) != totalShards { + return nil, fmt.Errorf("incomplete shard set: found %d of %d shards for %s", len(shards), totalShards, baseName) } - return nil + // Shards are already in order due to sequential loop + return shards, nil } func configFromFiles() types.Config { From ecd743e9b5aba20f55b21adb076680107d5611a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ignacio=20L=C3=B3pez=20Luna?= Date: Wed, 1 Oct 2025 17:31:06 +0200 Subject: [PATCH 18/19] feat: remove unused ConfigDir method from fakeBundle --- pkg/inference/backends/llamacpp/llamacpp_config_test.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index cfbbe4424..ad30e95d5 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -283,10 +283,6 @@ func (f *fakeBundle) SafetensorsPath() string { return "" } -func (f *fakeBundle) ConfigDir() string { - return "" -} - func (f *fakeBundle) RuntimeConfig() types.Config { return f.config } From 3e7213e45ba7045b1146ecc69734dc483d1b2c97 Mon Sep 17 00:00:00 2001 From: Ignasi Date: Wed, 1 Oct 2025 17:33:17 +0200 Subject: [PATCH 19/19] Update pkg/distribution/internal/bundle/unpack.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pkg/distribution/internal/bundle/unpack.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 0f2ff6ccd..a70b952a6 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -210,7 +210,7 @@ func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { } // validatePathWithinDirectory checks if targetPath is within baseDir to prevent directory traversal attacks. -// It uses filepath.IsLocal() which is available in Go 1.20+ and provides robust security against +// It uses filepath.IsLocal() to provide robust security against // various directory traversal attempts including edge cases like empty paths, ".", "..", symbolic links, etc. func validatePathWithinDirectory(baseDir, targetPath string) error { // Get absolute path of base directory