Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 93 additions & 15 deletions cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"fmt"
"html"
"io"
"os"
"path/filepath"

"github.com/docker/model-runner/pkg/distribution/builder"
"github.com/docker/model-runner/pkg/distribution/packaging"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
"github.com/docker/model-runner/pkg/distribution/types"
Expand All @@ -24,10 +26,11 @@ func newPackagedCmd() *cobra.Command {
var opts packageOptions

c := &cobra.Command{
Use: "package --gguf <path> [--license <path>...] [--context-size <tokens>] [--push] MODEL",
Short: "Package a GGUF file into a Docker model OCI artifact, with optional licenses.",
Long: "Package a GGUF file into a Docker model OCI artifact, with optional licenses. The package is sent to the model-runner, unless --push is specified.\n" +
"When packaging a sharded model --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf).",
Use: "package (--gguf <path> | --safetensors-dir <path>) [--license <path>...] [--context-size <tokens>] [--push] MODEL",
Short: "Package a GGUF file or Safetensors directory into a Docker model OCI artifact.",
Long: "Package a GGUF file or Safetensors directory into a Docker model OCI artifact, with optional licenses. The package is sent to the model-runner, unless --push is specified.\n" +
"When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf).\n" +
"When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). All files will be auto-discovered and config files will be packaged into a tar archive.",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf(
Expand All @@ -37,19 +40,62 @@ func newPackagedCmd() *cobra.Command {
cmd.Use,
)
}
if opts.ggufPath == "" {

// Validate that either --gguf or --safetensors-dir is provided (mutually exclusive)
if opts.ggufPath == "" && opts.safetensorsDir == "" {
return fmt.Errorf(
"GGUF path is required.\n\n" +
"Either --gguf or --safetensors-dir path is required.\n\n" +
"See 'docker model package --help' for more information",
)
}
if !filepath.IsAbs(opts.ggufPath) {
if opts.ggufPath != "" && opts.safetensorsDir != "" {
return fmt.Errorf(
"GGUF path must be absolute.\n\n" +
"Cannot specify both --gguf and --safetensors-dir. Please use only one format.\n\n" +
"See 'docker model package --help' for more information",
)
}
opts.ggufPath = filepath.Clean(opts.ggufPath)

// Validate GGUF path if provided
if opts.ggufPath != "" {
if !filepath.IsAbs(opts.ggufPath) {
return fmt.Errorf(
"GGUF path must be absolute.\n\n" +
"See 'docker model package --help' for more information",
)
}
opts.ggufPath = filepath.Clean(opts.ggufPath)
}

// Validate safetensors directory if provided
if opts.safetensorsDir != "" {
if !filepath.IsAbs(opts.safetensorsDir) {
return fmt.Errorf(
"Safetensors directory path must be absolute.\n\n" +
"See 'docker model package --help' for more information",
)
}
opts.safetensorsDir = filepath.Clean(opts.safetensorsDir)

// Check if it's a directory
info, err := os.Stat(opts.safetensorsDir)
if err != nil {
if os.IsNotExist(err) {
return fmt.Errorf(
"Safetensors directory does not exist: %s\n\n"+
"See 'docker model package --help' for more information",
opts.safetensorsDir,
)
}
return fmt.Errorf("could not access safetensors directory %q: %w", opts.safetensorsDir, err)
}
if !info.IsDir() {
return fmt.Errorf(
"Safetensors path must be a directory: %s\n\n"+
"See 'docker model package --help' for more information",
opts.safetensorsDir,
)
}
}

for i, l := range opts.licensePaths {
if !filepath.IsAbs(l) {
Expand All @@ -73,7 +119,8 @@ func newPackagedCmd() *cobra.Command {
ValidArgsFunction: completion.NoComplete,
}

c.Flags().StringVar(&opts.ggufPath, "gguf", "", "absolute path to gguf file (required)")
c.Flags().StringVar(&opts.ggufPath, "gguf", "", "absolute path to gguf file")
c.Flags().StringVar(&opts.safetensorsDir, "safetensors-dir", "", "absolute path to directory containing safetensors files and config")
c.Flags().StringVar(&opts.chatTemplatePath, "chat-template", "", "absolute path to chat template file (must be Jinja format)")
c.Flags().StringArrayVarP(&opts.licensePaths, "license", "l", nil, "absolute path to a license file")
c.Flags().BoolVar(&opts.push, "push", false, "push to registry (if not set, the model is loaded into the Model Runner content store)")
Expand All @@ -85,6 +132,7 @@ type packageOptions struct {
chatTemplatePath string
contextSize uint64
ggufPath string
safetensorsDir string
licensePaths []string
push bool
tag string
Expand All @@ -106,11 +154,41 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
return err
}

// Create package builder with GGUF file
cmd.PrintErrf("Adding GGUF file from %q\n", opts.ggufPath)
pkg, err := builder.FromGGUF(opts.ggufPath)
if err != nil {
return fmt.Errorf("add gguf file: %w", err)
// Create package builder based on model format
var pkg *builder.Builder
if opts.ggufPath != "" {
cmd.PrintErrf("Adding GGUF file from %q\n", opts.ggufPath)
pkg, err = builder.FromGGUF(opts.ggufPath)
if err != nil {
return fmt.Errorf("add gguf file: %w", err)
}
} else {
// Safetensors model from directory
cmd.PrintErrf("Scanning directory %q for safetensors model...\n", opts.safetensorsDir)
safetensorsPaths, tempConfigArchive, err := packaging.PackageFromDirectory(opts.safetensorsDir)
if err != nil {
return fmt.Errorf("scan safetensors directory: %w", err)
}

// Clean up temp config archive when done
if tempConfigArchive != "" {
defer os.Remove(tempConfigArchive)
}

cmd.PrintErrf("Found %d safetensors file(s)\n", len(safetensorsPaths))
pkg, err = builder.FromSafetensors(safetensorsPaths)
if err != nil {
return fmt.Errorf("create safetensors model: %w", err)
}

// Add config archive if it was created
if tempConfigArchive != "" {
cmd.PrintErrf("Adding config archive from directory\n")
pkg, err = pkg.WithConfigArchive(tempConfigArchive)
if err != nil {
return fmt.Errorf("add config archive: %w", err)
}
}
}

// Set context size
Expand Down
135 changes: 2 additions & 133 deletions cmd/mdltool/main.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package main

import (
"archive/tar"
"context"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"

"github.com/docker/model-runner/pkg/distribution/builder"
"github.com/docker/model-runner/pkg/distribution/distribution"
"github.com/docker/model-runner/pkg/distribution/packaging"
"github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/tarball"
)
Expand Down Expand Up @@ -220,7 +218,7 @@ func cmdPackage(args []string) int {
if sourceInfo.IsDir() {
fmt.Printf("Detected directory, scanning for safetensors model...\n")
var err error
safetensorsPaths, configArchive, err = packageFromDirectory(source)
safetensorsPaths, configArchive, err = packaging.PackageFromDirectory(source)
if err != nil {
fmt.Fprintf(os.Stderr, "Error scanning directory: %v\n", err)
return 1
Expand Down Expand Up @@ -581,132 +579,3 @@ func cmdBundle(client *distribution.Client, args []string) int {
fmt.Fprint(os.Stdout, bundle.RootDir())
return 0
}

// packageFromDirectory inspects the top level of dirPath (subdirectories are
// not descended into) and sorts what it finds into two groups:
//
//   - model weights: files whose name ends in ".safetensors" (case-insensitive)
//   - config files: files whose name ends in ".json" (case-insensitive) or is
//     exactly "merges.txt"
//
// The weight paths are returned sorted. When at least one config file is
// present, the config files are bundled (also in sorted order) into a
// temporary tar archive whose path is returned as tempConfigArchive; otherwise
// tempConfigArchive is empty. An error is returned when the directory cannot
// be read, when it holds no safetensors files, or when the archive cannot be
// created.
func packageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfigArchive string, err error) {
	dirEntries, err := os.ReadDir(dirPath)
	if err != nil {
		return nil, "", fmt.Errorf("read directory: %w", err)
	}

	var configPaths []string
	for _, e := range dirEntries {
		// Only regular top-level entries are considered.
		if e.IsDir() {
			continue
		}

		fileName := e.Name()
		lowered := strings.ToLower(fileName)
		full := filepath.Join(dirPath, fileName)

		switch {
		case strings.HasSuffix(lowered, ".safetensors"):
			safetensorsPaths = append(safetensorsPaths, full)
		case strings.HasSuffix(lowered, ".json"), fileName == "merges.txt":
			configPaths = append(configPaths, full)
		}
	}

	if len(safetensorsPaths) == 0 {
		return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath)
	}

	// A stable ordering keeps the resulting artifact reproducible.
	sort.Strings(safetensorsPaths)

	// Nothing to archive: return the weights alone.
	if len(configPaths) == 0 {
		return safetensorsPaths, "", nil
	}

	// Sorted input gives a deterministic entry order inside the tar archive.
	sort.Strings(configPaths)

	archivePath, archiveErr := createTempConfigArchive(configPaths)
	if archiveErr != nil {
		return nil, "", fmt.Errorf("create config archive: %w", archiveErr)
	}

	return safetensorsPaths, archivePath, nil
}

// createTempConfigArchive writes the given config files into a newly created
// temporary tar archive (pattern "vllm-config-*.tar" in the default temp
// directory) and returns the archive's path.
//
// Every file is stored under its base name only, so the archive is flat
// regardless of where the inputs live. On success the caller owns the returned
// file and is responsible for removing it. On any failure the partially
// written archive is closed and deleted, and an empty path is returned with
// the error.
func createTempConfigArchive(configFiles []string) (string, error) {
	tmpFile, err := os.CreateTemp("", "vllm-config-*.tar")
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()

	if err := writeConfigFilesTar(tmpFile, configFiles); err != nil {
		// Best-effort cleanup: the write error is the one worth reporting,
		// so failures while discarding the partial archive are ignored.
		_ = tmpFile.Close()
		_ = os.Remove(tmpPath)
		return "", err
	}

	if err := tmpFile.Close(); err != nil {
		_ = os.Remove(tmpPath)
		return "", fmt.Errorf("close temp file: %w", err)
	}

	return tmpPath, nil
}

// writeConfigFilesTar streams configFiles, in the order given, into a tar
// archive written to w and finalizes the archive.
func writeConfigFilesTar(w io.Writer, configFiles []string) error {
	tw := tar.NewWriter(w)

	for _, filePath := range configFiles {
		if err := addConfigFileToTar(tw, filePath); err != nil {
			return err
		}
	}

	// Close flushes the tar trailer; without it the archive is truncated.
	if err := tw.Close(); err != nil {
		return fmt.Errorf("close tar writer: %w", err)
	}
	return nil
}

// addConfigFileToTar appends a single file to tw as an entry named after the
// file's base name, carrying its size, permission bits and modification time.
func addConfigFileToTar(tw *tar.Writer, filePath string) error {
	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("open config file %s: %w", filePath, err)
	}
	defer file.Close()

	fileInfo, err := file.Stat()
	if err != nil {
		return fmt.Errorf("stat config file %s: %w", filePath, err)
	}

	header := &tar.Header{
		// Use only the base name so the archive does not leak host paths.
		Name: filepath.Base(filePath),
		Size: fileInfo.Size(),
		// os.FileMode keeps its type/setuid/setgid/sticky flags in Go-specific
		// high bits that do not match tar's Unix mode layout; only the
		// permission bits are portable.
		Mode:    int64(fileInfo.Mode().Perm()),
		ModTime: fileInfo.ModTime(),
	}

	if err := tw.WriteHeader(header); err != nil {
		return fmt.Errorf("write tar header for %s: %w", filePath, err)
	}

	if _, err := io.Copy(tw, file); err != nil {
		return fmt.Errorf("write tar content for %s: %w", filePath, err)
	}

	return nil
}
Loading
Loading