Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
f10a5da
feat: add support for safetensors model format and related functionality
ilopezluna Sep 30, 2025
74cd834
implement extraction logic
ilopezluna Sep 30, 2025
bdb03a8
feat: enhance safetensors support with sharded model discovery and pa…
ilopezluna Sep 30, 2025
50814ad
feat: add regex pattern for safetensors shard filename matching
ilopezluna Sep 30, 2025
7a65e5b
feat: enhance security with path validation to prevent directory trav…
ilopezluna Oct 1, 2025
68a8f28
feat: skip symlinks in model distribution to prevent directory traver…
ilopezluna Oct 1, 2025
6a427c9
feat: update config loading to use allPaths for model creation
ilopezluna Oct 1, 2025
af3e40e
feat: improve error handling for missing config archive in unpackConf…
ilopezluna Oct 1, 2025
68abdd9
feat: prevent duplicate config archive layers during model creation
ilopezluna Oct 1, 2025
a155d95
feat: update packaging command in Makefile for model distribution
ilopezluna Oct 1, 2025
57175cf
feat: update model file handling to differentiate between GGUF and sa…
ilopezluna Oct 1, 2025
ffd15fd
feat: remove config directory handling from bundle and unpack logic
ilopezluna Oct 1, 2025
5829540
Update pkg/distribution/internal/bundle/unpack.go
ilopezluna Oct 1, 2025
902cb3c
feat: ensure reproducibility by sorting safetensors and config files …
ilopezluna Oct 1, 2025
703344e
feat: simplify safetensors model creation by removing config archive …
ilopezluna Oct 1, 2025
9ad6432
simplify
ilopezluna Oct 1, 2025
2db1646
feat: enhance shard discovery by adding error handling for incomplete…
ilopezluna Oct 1, 2025
ecd743e
feat: remove unused ConfigDir method from fakeBundle
ilopezluna Oct 1, 2025
3e7213e
Update pkg/distribution/internal/bundle/unpack.go
ilopezluna Oct 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ mdl-pull: model-distribution-tool

mdl-package: model-distribution-tool
@echo "Packaging model $(SOURCE) to $(TAG)..."
./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) package $(SOURCE) --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE))
./$(MDL_TOOL_NAME) package --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE)) $(SOURCE)

mdl-list: model-distribution-tool
@echo "Listing models..."
Expand Down Expand Up @@ -140,5 +140,6 @@ help:
@echo "Model distribution tool examples:"
@echo " make mdl-pull TAG=registry.example.com/models/llama:v1.0"
@echo " make mdl-package SOURCE=./model.gguf TAG=registry.example.com/models/llama:v1.0 LICENSE=./license.txt"
@echo " make mdl-package SOURCE=./qwen2.5-3b-instruct TAG=registry.example.com/models/qwen:v1.0"
@echo " make mdl-list"
@echo " make mdl-rm TAG=registry.example.com/models/llama:v1.0"
243 changes: 214 additions & 29 deletions cmd/mdltool/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package main

import (
"archive/tar"
"context"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"

"github.com/docker/model-runner/pkg/distribution/builder"
Expand Down Expand Up @@ -178,7 +181,12 @@ func cmdPackage(args []string) int {
fs.StringVar(&chatTemplate, "chat-template", "", "Jinja chat template file")

fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf>\n\n")
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-model-or-directory>\n\n")
fmt.Fprintf(os.Stderr, "Examples:\n")
fmt.Fprintf(os.Stderr, " # GGUF model:\n")
fmt.Fprintf(os.Stderr, " model-distribution-tool package model.gguf --tag registry/model:tag\n\n")
fmt.Fprintf(os.Stderr, " # Safetensors model:\n")
fmt.Fprintf(os.Stderr, " model-distribution-tool package ./qwen-model-dir --tag registry/model:tag\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
fs.PrintDefaults()
}
Expand All @@ -189,32 +197,62 @@ func cmdPackage(args []string) int {
}
args = fs.Args()

// Get the source from positional argument
if len(args) < 1 {
fmt.Fprintf(os.Stderr, "Error: missing arguments\n")
fmt.Fprintf(os.Stderr, "Error: no model file or directory specified\n")
fs.Usage()
return 1
}
if file == "" && tag == "" {
fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n")
fs.Usage()

source := args[0]
var isSafetensors bool
var configArchive string // For safetensors config
var safetensorsPaths []string // For safetensors model files

// Check if source exists
sourceInfo, err := os.Stat(source)
if os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: source does not exist: %s\n", source)
return 1
}

source := args[0]
ctx := context.Background()
// Handle directory-based packaging (for safetensors models)
if sourceInfo.IsDir() {
fmt.Printf("Detected directory, scanning for safetensors model...\n")
var err error
safetensorsPaths, configArchive, err = packageFromDirectory(source)
if err != nil {
fmt.Fprintf(os.Stderr, "Error scanning directory: %v\n", err)
return 1
}

// Check if source file exists
if _, err := os.Stat(source); os.IsNotExist(err) {
fmt.Fprintf(os.Stderr, "Error: source file does not exist: %s\n", source)
return 1
isSafetensors = true
fmt.Printf("Found %d safetensors file(s)\n", len(safetensorsPaths))

// Clean up temp config archive when done
if configArchive != "" {
defer os.Remove(configArchive)
fmt.Printf("Created temporary config archive from directory\n")
}
} else {
// Handle single file (GGUF model)
if strings.HasSuffix(strings.ToLower(source), ".gguf") {
isSafetensors = false
fmt.Println("Detected GGUF model file")
} else {
fmt.Fprintf(os.Stderr, "Warning: could not determine model type for: %s\n", source)
fmt.Fprintf(os.Stderr, "Assuming GGUF format.\n")
}
}

// Check if source file is a GGUF file
if !strings.HasSuffix(strings.ToLower(source), ".gguf") {
fmt.Fprintf(os.Stderr, "Warning: source file does not have .gguf extension: %s\n", source)
fmt.Fprintf(os.Stderr, "Continuing anyway, but this may cause issues.\n")
if file == "" && tag == "" {
fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n")
fs.Usage()
return 1
}

ctx := context.Background()

// Prepare registry client options
registryClientOpts := []registry.ClientOption{
registry.WithUserAgent("model-distribution-tool/" + version),
Expand All @@ -230,31 +268,49 @@ func cmdPackage(args []string) int {
// Create registry client once with all options
registryClient := registry.NewClient(registryClientOpts...)

var (
target builder.Target
err error
)
var target builder.Target
if file != "" {
target = tarball.NewFileTarget(file)
} else {
var err error
target, err = registryClient.NewTarget(tag)
if err != nil {
fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err)
return 1
}
}

// Create image with layer
builder, err := builder.FromGGUF(source)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err)
return 1
// Create builder based on model type
var b *builder.Builder
if isSafetensors {
fmt.Println("Creating safetensors model")
b, err = builder.FromSafetensors(safetensorsPaths)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating model from safetensors: %v\n", err)
return 1
}

// Add config archive if provided
if configArchive != "" {
fmt.Printf("Adding config archive: %s\n", configArchive)
b, err = b.WithConfigArchive(configArchive)
if err != nil {
fmt.Fprintf(os.Stderr, "Error adding config archive: %v\n", err)
return 1
}
}
} else {
b, err = builder.FromGGUF(source)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err)
return 1
}
}

// Add all license files as layers
for _, path := range licensePaths {
fmt.Println("Adding license file:", path)
builder, err = builder.WithLicense(path)
b, err = b.WithLicense(path)
if err != nil {
fmt.Fprintf(os.Stderr, "Error adding license layer for %s: %v\n", path, err)
return 1
Expand All @@ -263,12 +319,12 @@ func cmdPackage(args []string) int {

if contextSize > 0 {
fmt.Println("Setting context size:", contextSize)
builder = builder.WithContextSize(contextSize)
b = b.WithContextSize(contextSize)
}

if mmproj != "" {
fmt.Println("Adding multimodal projector file:", mmproj)
builder, err = builder.WithMultimodalProjector(mmproj)
b, err = b.WithMultimodalProjector(mmproj)
if err != nil {
fmt.Fprintf(os.Stderr, "Error adding multimodal projector layer for %s: %v\n", mmproj, err)
return 1
Expand All @@ -277,15 +333,15 @@ func cmdPackage(args []string) int {

if chatTemplate != "" {
fmt.Println("Adding chat template file:", chatTemplate)
builder, err = builder.WithChatTemplateFile(chatTemplate)
b, err = b.WithChatTemplateFile(chatTemplate)
if err != nil {
fmt.Fprintf(os.Stderr, "Error adding chat template layer for %s: %v\n", chatTemplate, err)
return 1
}
}

// Push the image
if err := builder.Build(ctx, target, os.Stdout); err != nil {
if err := b.Build(ctx, target, os.Stdout); err != nil {
fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err)
return 1
}
Expand Down Expand Up @@ -525,3 +581,132 @@ func cmdBundle(client *distribution.Client, args []string) int {
fmt.Fprint(os.Stdout, bundle.RootDir())
return 0
}

// packageFromDirectory inspects the immediate children of dirPath (it does not
// descend into subdirectories) and sorts them into two groups:
//
//   - model weights: files whose name ends in ".safetensors" (case-insensitive)
//   - config files: files whose name ends in ".json" (case-insensitive) or is
//     exactly "merges.txt"
//
// The weight paths are returned in sorted order. When at least one config file
// is present, the sorted config paths are handed to createTempConfigArchive and
// the resulting archive path is returned as tempConfigArchive; otherwise that
// result is the empty string. An error is returned when the directory cannot be
// read, when it holds no safetensors files, or when the archive cannot be
// created.
func packageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfigArchive string, err error) {
	listing, readErr := os.ReadDir(dirPath)
	if readErr != nil {
		return nil, "", fmt.Errorf("read directory: %w", readErr)
	}

	var weights, configs []string
	for _, item := range listing {
		// Only regular top-level entries are considered; directories are ignored.
		if item.IsDir() {
			continue
		}

		base := item.Name()
		lowered := strings.ToLower(base)
		joined := filepath.Join(dirPath, base)

		if strings.HasSuffix(lowered, ".safetensors") {
			weights = append(weights, joined)
		}

		// The "merges.txt" match is deliberately case-sensitive, unlike the
		// extension checks.
		isConfig := strings.HasSuffix(lowered, ".json") || base == "merges.txt"
		if isConfig {
			configs = append(configs, joined)
		}
	}

	if len(weights) == 0 {
		return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath)
	}

	// Stable ordering keeps the produced artifact reproducible.
	sort.Strings(weights)

	// Nothing to archive: return the weights alone.
	if len(configs) == 0 {
		return weights, "", nil
	}

	// Same reasoning as above: a fixed member order inside the tar archive.
	sort.Strings(configs)

	archive, archiveErr := createTempConfigArchive(configs)
	if archiveErr != nil {
		return nil, "", fmt.Errorf("create config archive: %w", archiveErr)
	}
	return weights, archive, nil
}

// createTempConfigArchive writes the given config files into a new tar archive
// created in the default temporary directory (name pattern "vllm-config-*.tar")
// and returns the archive's path.
//
// Members are stored in the order given, under their base name only, so two
// inputs with the same base name produce two entries with the same name. Each
// header carries the source file's size, permission bits and modification time.
//
// The caller owns the returned file and is responsible for removing it. On any
// failure the partially written archive is removed and the returned path is
// empty.
func createTempConfigArchive(configFiles []string) (archivePath string, err error) {
	tmpFile, err := os.CreateTemp("", "vllm-config-*.tar")
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()

	// Single cleanup point for every error path below: whenever the function
	// returns a non-nil error the temp file is closed (a second Close after a
	// failed Close is harmless) and deleted, so nothing is leaked.
	defer func() {
		if err != nil {
			tmpFile.Close()
			os.Remove(tmpPath)
		}
	}()

	tw := tar.NewWriter(tmpFile)
	for _, filePath := range configFiles {
		if err = appendConfigFileToTar(tw, filePath); err != nil {
			return "", err
		}
	}

	// Closing the writer flushes the tar trailer; it must precede closing the file.
	if err = tw.Close(); err != nil {
		return "", fmt.Errorf("close tar writer: %w", err)
	}
	if err = tmpFile.Close(); err != nil {
		return "", fmt.Errorf("close temp file: %w", err)
	}

	return tmpPath, nil
}

// appendConfigFileToTar adds one regular file to tw as an entry named after the
// file's base name. The source file is always closed before returning.
func appendConfigFileToTar(tw *tar.Writer, filePath string) error {
	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("open config file %s: %w", filePath, err)
	}
	defer file.Close()

	fileInfo, err := file.Stat()
	if err != nil {
		return fmt.Errorf("stat config file %s: %w", filePath, err)
	}

	// os.Open follows symlinks, so the path may resolve to a directory or
	// another non-regular file; such entries cannot be stored as file content.
	if !fileInfo.Mode().IsRegular() {
		return fmt.Errorf("config file %s is not a regular file", filePath)
	}

	header := &tar.Header{
		Typeflag: tar.TypeReg,
		Name:     filepath.Base(filePath),
		Size:     fileInfo.Size(),
		// Only the permission bits: os.FileMode's flag bits (setuid, sticky,
		// ...) use Go-specific positions that are not valid tar mode values.
		Mode:    int64(fileInfo.Mode().Perm()),
		ModTime: fileInfo.ModTime(),
	}

	if err := tw.WriteHeader(header); err != nil {
		return fmt.Errorf("write tar header for %s: %w", filePath, err)
	}
	if _, err := io.Copy(tw, file); err != nil {
		return fmt.Errorf("write tar content for %s: %w", filePath, err)
	}
	return nil
}
Loading
Loading