Skip to content

Commit de7e2d2

Browse files
authored
Merge pull request #186 from docker/safetensors-as-oci-artifact
Safetensors as OCI Artifact
2 parents fca32cb + 3e7213e commit de7e2d2

File tree

12 files changed

+798
-37
lines changed

12 files changed

+798
-37
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ mdl-pull: model-distribution-tool
8686

8787
mdl-package: model-distribution-tool
8888
@echo "Packaging model $(SOURCE) to $(TAG)..."
89-
./$(MDL_TOOL_NAME) --store-path $(STORE_PATH) package $(SOURCE) --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE))
89+
./$(MDL_TOOL_NAME) package --tag $(TAG) $(if $(LICENSE),--licenses $(LICENSE)) $(SOURCE)
9090

9191
mdl-list: model-distribution-tool
9292
@echo "Listing models..."
@@ -140,5 +140,6 @@ help:
140140
@echo "Model distribution tool examples:"
141141
@echo " make mdl-pull TAG=registry.example.com/models/llama:v1.0"
142142
@echo " make mdl-package SOURCE=./model.gguf TAG=registry.example.com/models/llama:v1.0 LICENSE=./license.txt"
143+
@echo " make mdl-package SOURCE=./qwen2.5-3b-instruct TAG=registry.example.com/models/qwen:v1.0"
143144
@echo " make mdl-list"
144145
@echo " make mdl-rm TAG=registry.example.com/models/llama:v1.0"

cmd/mdltool/main.go

Lines changed: 214 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package main
22

33
import (
4+
"archive/tar"
45
"context"
56
"flag"
67
"fmt"
8+
"io"
79
"os"
810
"path/filepath"
11+
"sort"
912
"strings"
1013

1114
"github.com/docker/model-runner/pkg/distribution/builder"
@@ -178,7 +181,12 @@ func cmdPackage(args []string) int {
178181
fs.StringVar(&chatTemplate, "chat-template", "", "Jinja chat template file")
179182

180183
fs.Usage = func() {
181-
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-gguf>\n\n")
184+
fmt.Fprintf(os.Stderr, "Usage: model-distribution-tool package [OPTIONS] <path-to-model-or-directory>\n\n")
185+
fmt.Fprintf(os.Stderr, "Examples:\n")
186+
fmt.Fprintf(os.Stderr, " # GGUF model:\n")
187+
fmt.Fprintf(os.Stderr, " model-distribution-tool package model.gguf --tag registry/model:tag\n\n")
188+
fmt.Fprintf(os.Stderr, " # Safetensors model:\n")
189+
fmt.Fprintf(os.Stderr, " model-distribution-tool package ./qwen-model-dir --tag registry/model:tag\n\n")
182190
fmt.Fprintf(os.Stderr, "Options:\n")
183191
fs.PrintDefaults()
184192
}
@@ -189,32 +197,62 @@ func cmdPackage(args []string) int {
189197
}
190198
args = fs.Args()
191199

200+
// Get the source from positional argument
192201
if len(args) < 1 {
193-
fmt.Fprintf(os.Stderr, "Error: missing arguments\n")
202+
fmt.Fprintf(os.Stderr, "Error: no model file or directory specified\n")
194203
fs.Usage()
195204
return 1
196205
}
197-
if file == "" && tag == "" {
198-
fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n")
199-
fs.Usage()
206+
207+
source := args[0]
208+
var isSafetensors bool
209+
var configArchive string // For safetensors config
210+
var safetensorsPaths []string // For safetensors model files
211+
212+
// Check if source exists
213+
sourceInfo, err := os.Stat(source)
214+
if os.IsNotExist(err) {
215+
fmt.Fprintf(os.Stderr, "Error: source does not exist: %s\n", source)
200216
return 1
201217
}
202218

203-
source := args[0]
204-
ctx := context.Background()
219+
// Handle directory-based packaging (for safetensors models)
220+
if sourceInfo.IsDir() {
221+
fmt.Printf("Detected directory, scanning for safetensors model...\n")
222+
var err error
223+
safetensorsPaths, configArchive, err = packageFromDirectory(source)
224+
if err != nil {
225+
fmt.Fprintf(os.Stderr, "Error scanning directory: %v\n", err)
226+
return 1
227+
}
205228

206-
// Check if source file exists
207-
if _, err := os.Stat(source); os.IsNotExist(err) {
208-
fmt.Fprintf(os.Stderr, "Error: source file does not exist: %s\n", source)
209-
return 1
229+
isSafetensors = true
230+
fmt.Printf("Found %d safetensors file(s)\n", len(safetensorsPaths))
231+
232+
// Clean up temp config archive when done
233+
if configArchive != "" {
234+
defer os.Remove(configArchive)
235+
fmt.Printf("Created temporary config archive from directory\n")
236+
}
237+
} else {
238+
// Handle single file (GGUF model)
239+
if strings.HasSuffix(strings.ToLower(source), ".gguf") {
240+
isSafetensors = false
241+
fmt.Println("Detected GGUF model file")
242+
} else {
243+
fmt.Fprintf(os.Stderr, "Warning: could not determine model type for: %s\n", source)
244+
fmt.Fprintf(os.Stderr, "Assuming GGUF format.\n")
245+
}
210246
}
211247

212-
// Check if source file is a GGUF file
213-
if !strings.HasSuffix(strings.ToLower(source), ".gguf") {
214-
fmt.Fprintf(os.Stderr, "Warning: source file does not have .gguf extension: %s\n", source)
215-
fmt.Fprintf(os.Stderr, "Continuing anyway, but this may cause issues.\n")
248+
if file == "" && tag == "" {
249+
fmt.Fprintf(os.Stderr, "Error: one of --file or --tag is required\n")
250+
fs.Usage()
251+
return 1
216252
}
217253

254+
ctx := context.Background()
255+
218256
// Prepare registry client options
219257
registryClientOpts := []registry.ClientOption{
220258
registry.WithUserAgent("model-distribution-tool/" + version),
@@ -230,31 +268,49 @@ func cmdPackage(args []string) int {
230268
// Create registry client once with all options
231269
registryClient := registry.NewClient(registryClientOpts...)
232270

233-
var (
234-
target builder.Target
235-
err error
236-
)
271+
var target builder.Target
237272
if file != "" {
238273
target = tarball.NewFileTarget(file)
239274
} else {
275+
var err error
240276
target, err = registryClient.NewTarget(tag)
241277
if err != nil {
242278
fmt.Fprintf(os.Stderr, "Create packaging target: %v\n", err)
243279
return 1
244280
}
245281
}
246282

247-
// Create image with layer
248-
builder, err := builder.FromGGUF(source)
249-
if err != nil {
250-
fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err)
251-
return 1
283+
// Create builder based on model type
284+
var b *builder.Builder
285+
if isSafetensors {
286+
fmt.Println("Creating safetensors model")
287+
b, err = builder.FromSafetensors(safetensorsPaths)
288+
if err != nil {
289+
fmt.Fprintf(os.Stderr, "Error creating model from safetensors: %v\n", err)
290+
return 1
291+
}
292+
293+
// Add config archive if provided
294+
if configArchive != "" {
295+
fmt.Printf("Adding config archive: %s\n", configArchive)
296+
b, err = b.WithConfigArchive(configArchive)
297+
if err != nil {
298+
fmt.Fprintf(os.Stderr, "Error adding config archive: %v\n", err)
299+
return 1
300+
}
301+
}
302+
} else {
303+
b, err = builder.FromGGUF(source)
304+
if err != nil {
305+
fmt.Fprintf(os.Stderr, "Error creating model from gguf: %v\n", err)
306+
return 1
307+
}
252308
}
253309

254310
// Add all license files as layers
255311
for _, path := range licensePaths {
256312
fmt.Println("Adding license file:", path)
257-
builder, err = builder.WithLicense(path)
313+
b, err = b.WithLicense(path)
258314
if err != nil {
259315
fmt.Fprintf(os.Stderr, "Error adding license layer for %s: %v\n", path, err)
260316
return 1
@@ -263,12 +319,12 @@ func cmdPackage(args []string) int {
263319

264320
if contextSize > 0 {
265321
fmt.Println("Setting context size:", contextSize)
266-
builder = builder.WithContextSize(contextSize)
322+
b = b.WithContextSize(contextSize)
267323
}
268324

269325
if mmproj != "" {
270326
fmt.Println("Adding multimodal projector file:", mmproj)
271-
builder, err = builder.WithMultimodalProjector(mmproj)
327+
b, err = b.WithMultimodalProjector(mmproj)
272328
if err != nil {
273329
fmt.Fprintf(os.Stderr, "Error adding multimodal projector layer for %s: %v\n", mmproj, err)
274330
return 1
@@ -277,15 +333,15 @@ func cmdPackage(args []string) int {
277333

278334
if chatTemplate != "" {
279335
fmt.Println("Adding chat template file:", chatTemplate)
280-
builder, err = builder.WithChatTemplateFile(chatTemplate)
336+
b, err = b.WithChatTemplateFile(chatTemplate)
281337
if err != nil {
282338
fmt.Fprintf(os.Stderr, "Error adding chat template layer for %s: %v\n", chatTemplate, err)
283339
return 1
284340
}
285341
}
286342

287343
// Push the image
288-
if err := builder.Build(ctx, target, os.Stdout); err != nil {
344+
if err := b.Build(ctx, target, os.Stdout); err != nil {
289345
fmt.Fprintf(os.Stderr, "Error writing model to registry: %v\n", err)
290346
return 1
291347
}
@@ -525,3 +581,132 @@ func cmdBundle(client *distribution.Client, args []string) int {
525581
fmt.Fprint(os.Stdout, bundle.RootDir())
526582
return 0
527583
}
584+
585+
// packageFromDirectory scans a directory for safetensors files and config files,
586+
// creating a temporary tar archive of the config files
587+
func packageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfigArchive string, err error) {
588+
// Read directory contents (only top level, no subdirectories)
589+
entries, err := os.ReadDir(dirPath)
590+
if err != nil {
591+
return nil, "", fmt.Errorf("read directory: %w", err)
592+
}
593+
594+
var configFiles []string
595+
596+
for _, entry := range entries {
597+
if entry.IsDir() {
598+
continue // Skip subdirectories
599+
}
600+
601+
name := entry.Name()
602+
fullPath := filepath.Join(dirPath, name)
603+
604+
// Collect safetensors files
605+
if strings.HasSuffix(strings.ToLower(name), ".safetensors") {
606+
safetensorsPaths = append(safetensorsPaths, fullPath)
607+
}
608+
609+
// Collect config files: *.json, merges.txt
610+
if strings.HasSuffix(strings.ToLower(name), ".json") ||
611+
name == "merges.txt" {
612+
configFiles = append(configFiles, fullPath)
613+
}
614+
}
615+
616+
if len(safetensorsPaths) == 0 {
617+
return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath)
618+
}
619+
620+
// Sort to ensure reproducible artifacts
621+
sort.Strings(safetensorsPaths)
622+
623+
// Create temporary tar archive with config files if any exist
624+
if len(configFiles) > 0 {
625+
// Sort config files for reproducible tar archive
626+
sort.Strings(configFiles)
627+
628+
tempConfigArchive, err = createTempConfigArchive(configFiles)
629+
if err != nil {
630+
return nil, "", fmt.Errorf("create config archive: %w", err)
631+
}
632+
}
633+
634+
return safetensorsPaths, tempConfigArchive, nil
635+
}
636+
637+
// createTempConfigArchive creates a temporary tar archive containing the specified config files
638+
func createTempConfigArchive(configFiles []string) (string, error) {
639+
// Create temp file
640+
tmpFile, err := os.CreateTemp("", "vllm-config-*.tar")
641+
if err != nil {
642+
return "", fmt.Errorf("create temp file: %w", err)
643+
}
644+
tmpPath := tmpFile.Name()
645+
646+
// Create tar writer
647+
tw := tar.NewWriter(tmpFile)
648+
649+
// Add each config file to tar (preserving just filename, not full path)
650+
for _, filePath := range configFiles {
651+
// Open the file
652+
file, err := os.Open(filePath)
653+
if err != nil {
654+
tw.Close()
655+
tmpFile.Close()
656+
os.Remove(tmpPath)
657+
return "", fmt.Errorf("open config file %s: %w", filePath, err)
658+
}
659+
660+
// Get file info for tar header
661+
fileInfo, err := file.Stat()
662+
if err != nil {
663+
file.Close()
664+
tw.Close()
665+
tmpFile.Close()
666+
os.Remove(tmpPath)
667+
return "", fmt.Errorf("stat config file %s: %w", filePath, err)
668+
}
669+
670+
// Create tar header (use only basename, not full path)
671+
header := &tar.Header{
672+
Name: filepath.Base(filePath),
673+
Size: fileInfo.Size(),
674+
Mode: int64(fileInfo.Mode()),
675+
ModTime: fileInfo.ModTime(),
676+
}
677+
678+
// Write header
679+
if err := tw.WriteHeader(header); err != nil {
680+
file.Close()
681+
tw.Close()
682+
tmpFile.Close()
683+
os.Remove(tmpPath)
684+
return "", fmt.Errorf("write tar header for %s: %w", filePath, err)
685+
}
686+
687+
// Copy file contents
688+
if _, err := io.Copy(tw, file); err != nil {
689+
file.Close()
690+
tw.Close()
691+
tmpFile.Close()
692+
os.Remove(tmpPath)
693+
return "", fmt.Errorf("write tar content for %s: %w", filePath, err)
694+
}
695+
696+
file.Close()
697+
}
698+
699+
// Close tar writer and file
700+
if err := tw.Close(); err != nil {
701+
tmpFile.Close()
702+
os.Remove(tmpPath)
703+
return "", fmt.Errorf("close tar writer: %w", err)
704+
}
705+
706+
if err := tmpFile.Close(); err != nil {
707+
os.Remove(tmpPath)
708+
return "", fmt.Errorf("close temp file: %w", err)
709+
}
710+
711+
return tmpPath, nil
712+
}

0 commit comments

Comments
 (0)