diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index bb6c19081..fa1683f30 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -7,9 +7,11 @@ import ( "fmt" "html" "io" + "os" "path/filepath" "github.com/docker/model-runner/pkg/distribution/builder" + "github.com/docker/model-runner/pkg/distribution/packaging" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" "github.com/docker/model-runner/pkg/distribution/types" @@ -24,10 +26,11 @@ func newPackagedCmd() *cobra.Command { var opts packageOptions c := &cobra.Command{ - Use: "package --gguf [--license ...] [--context-size ] [--push] MODEL", - Short: "Package a GGUF file into a Docker model OCI artifact, with optional licenses.", - Long: "Package a GGUF file into a Docker model OCI artifact, with optional licenses. The package is sent to the model-runner, unless --push is specified.\n" + - "When packaging a sharded model --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf).", + Use: "package (--gguf | --safetensors-dir ) [--license ...] [--context-size ] [--push] MODEL", + Short: "Package a GGUF file or Safetensors directory into a Docker model OCI artifact.", + Long: "Package a GGUF file or Safetensors directory into a Docker model OCI artifact, with optional licenses. The package is sent to the model-runner, unless --push is specified.\n" + + "When packaging a sharded GGUF model, --gguf should point to the first shard. All shard files should be siblings and should include the index in the file name (e.g. model-00001-of-00015.gguf).\n" + + "When packaging a Safetensors model, --safetensors-dir should point to a directory containing .safetensors files and config files (*.json, merges.txt). 
All files will be auto-discovered and config files will be packaged into a tar archive.", Args: func(cmd *cobra.Command, args []string) error { if len(args) != 1 { return fmt.Errorf( @@ -37,19 +40,62 @@ func newPackagedCmd() *cobra.Command { cmd.Use, ) } - if opts.ggufPath == "" { + + // Validate that either --gguf or --safetensors-dir is provided (mutually exclusive) + if opts.ggufPath == "" && opts.safetensorsDir == "" { return fmt.Errorf( - "GGUF path is required.\n\n" + + "Either --gguf or --safetensors-dir path is required.\n\n" + "See 'docker model package --help' for more information", ) } - if !filepath.IsAbs(opts.ggufPath) { + if opts.ggufPath != "" && opts.safetensorsDir != "" { return fmt.Errorf( - "GGUF path must be absolute.\n\n" + + "Cannot specify both --gguf and --safetensors-dir. Please use only one format.\n\n" + "See 'docker model package --help' for more information", ) } - opts.ggufPath = filepath.Clean(opts.ggufPath) + + // Validate GGUF path if provided + if opts.ggufPath != "" { + if !filepath.IsAbs(opts.ggufPath) { + return fmt.Errorf( + "GGUF path must be absolute.\n\n" + + "See 'docker model package --help' for more information", + ) + } + opts.ggufPath = filepath.Clean(opts.ggufPath) + } + + // Validate safetensors directory if provided + if opts.safetensorsDir != "" { + if !filepath.IsAbs(opts.safetensorsDir) { + return fmt.Errorf( + "Safetensors directory path must be absolute.\n\n" + + "See 'docker model package --help' for more information", + ) + } + opts.safetensorsDir = filepath.Clean(opts.safetensorsDir) + + // Check if it's a directory + info, err := os.Stat(opts.safetensorsDir) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf( + "Safetensors directory does not exist: %s\n\n"+ + "See 'docker model package --help' for more information", + opts.safetensorsDir, + ) + } + return fmt.Errorf("could not access safetensors directory %q: %w", opts.safetensorsDir, err) + } + if !info.IsDir() { + return fmt.Errorf( + 
"Safetensors path must be a directory: %s\n\n"+ + "See 'docker model package --help' for more information", + opts.safetensorsDir, + ) + } + } for i, l := range opts.licensePaths { if !filepath.IsAbs(l) { @@ -73,7 +119,8 @@ func newPackagedCmd() *cobra.Command { ValidArgsFunction: completion.NoComplete, } - c.Flags().StringVar(&opts.ggufPath, "gguf", "", "absolute path to gguf file (required)") + c.Flags().StringVar(&opts.ggufPath, "gguf", "", "absolute path to gguf file") + c.Flags().StringVar(&opts.safetensorsDir, "safetensors-dir", "", "absolute path to directory containing safetensors files and config") c.Flags().StringVar(&opts.chatTemplatePath, "chat-template", "", "absolute path to chat template file (must be Jinja format)") c.Flags().StringArrayVarP(&opts.licensePaths, "license", "l", nil, "absolute path to a license file") c.Flags().BoolVar(&opts.push, "push", false, "push to registry (if not set, the model is loaded into the Model Runner content store)") @@ -85,6 +132,7 @@ type packageOptions struct { chatTemplatePath string contextSize uint64 ggufPath string + safetensorsDir string licensePaths []string push bool tag string @@ -106,11 +154,41 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error { return err } - // Create package builder with GGUF file - cmd.PrintErrf("Adding GGUF file from %q\n", opts.ggufPath) - pkg, err := builder.FromGGUF(opts.ggufPath) - if err != nil { - return fmt.Errorf("add gguf file: %w", err) + // Create package builder based on model format + var pkg *builder.Builder + if opts.ggufPath != "" { + cmd.PrintErrf("Adding GGUF file from %q\n", opts.ggufPath) + pkg, err = builder.FromGGUF(opts.ggufPath) + if err != nil { + return fmt.Errorf("add gguf file: %w", err) + } + } else { + // Safetensors model from directory + cmd.PrintErrf("Scanning directory %q for safetensors model...\n", opts.safetensorsDir) + safetensorsPaths, tempConfigArchive, err := packaging.PackageFromDirectory(opts.safetensorsDir) + if err != nil 
{ + return fmt.Errorf("scan safetensors directory: %w", err) + } + + // Clean up temp config archive when done + if tempConfigArchive != "" { + defer os.Remove(tempConfigArchive) + } + + cmd.PrintErrf("Found %d safetensors file(s)\n", len(safetensorsPaths)) + pkg, err = builder.FromSafetensors(safetensorsPaths) + if err != nil { + return fmt.Errorf("create safetensors model: %w", err) + } + + // Add config archive if it was created + if tempConfigArchive != "" { + cmd.PrintErrf("Adding config archive from directory\n") + pkg, err = pkg.WithConfigArchive(tempConfigArchive) + if err != nil { + return fmt.Errorf("add config archive: %w", err) + } + } } // Set context size diff --git a/cmd/mdltool/main.go b/cmd/mdltool/main.go index 490551a52..79faccf62 100644 --- a/cmd/mdltool/main.go +++ b/cmd/mdltool/main.go @@ -1,18 +1,16 @@ package main import ( - "archive/tar" "context" "flag" "fmt" - "io" "os" "path/filepath" - "sort" "strings" "github.com/docker/model-runner/pkg/distribution/builder" "github.com/docker/model-runner/pkg/distribution/distribution" + "github.com/docker/model-runner/pkg/distribution/packaging" "github.com/docker/model-runner/pkg/distribution/registry" "github.com/docker/model-runner/pkg/distribution/tarball" ) @@ -220,7 +218,7 @@ func cmdPackage(args []string) int { if sourceInfo.IsDir() { fmt.Printf("Detected directory, scanning for safetensors model...\n") var err error - safetensorsPaths, configArchive, err = packageFromDirectory(source) + safetensorsPaths, configArchive, err = packaging.PackageFromDirectory(source) if err != nil { fmt.Fprintf(os.Stderr, "Error scanning directory: %v\n", err) return 1 @@ -581,132 +579,3 @@ func cmdBundle(client *distribution.Client, args []string) int { fmt.Fprint(os.Stdout, bundle.RootDir()) return 0 } - -// packageFromDirectory scans a directory for safetensors files and config files, -// creating a temporary tar archive of the config files -func packageFromDirectory(dirPath string) (safetensorsPaths []string, 
tempConfigArchive string, err error) { - // Read directory contents (only top level, no subdirectories) - entries, err := os.ReadDir(dirPath) - if err != nil { - return nil, "", fmt.Errorf("read directory: %w", err) - } - - var configFiles []string - - for _, entry := range entries { - if entry.IsDir() { - continue // Skip subdirectories - } - - name := entry.Name() - fullPath := filepath.Join(dirPath, name) - - // Collect safetensors files - if strings.HasSuffix(strings.ToLower(name), ".safetensors") { - safetensorsPaths = append(safetensorsPaths, fullPath) - } - - // Collect config files: *.json, merges.txt - if strings.HasSuffix(strings.ToLower(name), ".json") || - name == "merges.txt" { - configFiles = append(configFiles, fullPath) - } - } - - if len(safetensorsPaths) == 0 { - return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath) - } - - // Sort to ensure reproducible artifacts - sort.Strings(safetensorsPaths) - - // Create temporary tar archive with config files if any exist - if len(configFiles) > 0 { - // Sort config files for reproducible tar archive - sort.Strings(configFiles) - - tempConfigArchive, err = createTempConfigArchive(configFiles) - if err != nil { - return nil, "", fmt.Errorf("create config archive: %w", err) - } - } - - return safetensorsPaths, tempConfigArchive, nil -} - -// createTempConfigArchive creates a temporary tar archive containing the specified config files -func createTempConfigArchive(configFiles []string) (string, error) { - // Create temp file - tmpFile, err := os.CreateTemp("", "vllm-config-*.tar") - if err != nil { - return "", fmt.Errorf("create temp file: %w", err) - } - tmpPath := tmpFile.Name() - - // Create tar writer - tw := tar.NewWriter(tmpFile) - - // Add each config file to tar (preserving just filename, not full path) - for _, filePath := range configFiles { - // Open the file - file, err := os.Open(filePath) - if err != nil { - tw.Close() - tmpFile.Close() - os.Remove(tmpPath) - return 
"", fmt.Errorf("open config file %s: %w", filePath, err) - } - - // Get file info for tar header - fileInfo, err := file.Stat() - if err != nil { - file.Close() - tw.Close() - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("stat config file %s: %w", filePath, err) - } - - // Create tar header (use only basename, not full path) - header := &tar.Header{ - Name: filepath.Base(filePath), - Size: fileInfo.Size(), - Mode: int64(fileInfo.Mode()), - ModTime: fileInfo.ModTime(), - } - - // Write header - if err := tw.WriteHeader(header); err != nil { - file.Close() - tw.Close() - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("write tar header for %s: %w", filePath, err) - } - - // Copy file contents - if _, err := io.Copy(tw, file); err != nil { - file.Close() - tw.Close() - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("write tar content for %s: %w", filePath, err) - } - - file.Close() - } - - // Close tar writer and file - if err := tw.Close(); err != nil { - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("close tar writer: %w", err) - } - - if err := tmpFile.Close(); err != nil { - os.Remove(tmpPath) - return "", fmt.Errorf("close temp file: %w", err) - } - - return tmpPath, nil -} diff --git a/pkg/distribution/packaging/safetensors.go b/pkg/distribution/packaging/safetensors.go new file mode 100644 index 000000000..46c27c7ec --- /dev/null +++ b/pkg/distribution/packaging/safetensors.go @@ -0,0 +1,146 @@ +package packaging + +import ( + "archive/tar" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" +) + +// PackageFromDirectory scans a directory for safetensors files and config files, +// creating a temporary tar archive of the config files. +// It returns the paths to safetensors files, path to temporary config archive (if created), +// and any error encountered. 
//
// When tempConfigArchive is non-empty the caller owns that file and is
// responsible for removing it once the archive has been consumed. No archive
// is created (and tempConfigArchive is "") when the directory holds no config
// files. An error is returned when the directory cannot be read, contains no
// *.safetensors file, or the config archive cannot be written.
func PackageFromDirectory(dirPath string) (safetensorsPaths []string, tempConfigArchive string, err error) {
	// Only the top level of dirPath is inspected; subdirectories are ignored.
	entries, err := os.ReadDir(dirPath)
	if err != nil {
		return nil, "", fmt.Errorf("read directory: %w", err)
	}

	var configFiles []string

	for _, entry := range entries {
		if entry.IsDir() {
			continue // Skip subdirectories
		}

		name := entry.Name()
		lowerName := strings.ToLower(name)
		fullPath := filepath.Join(dirPath, name)

		// Weight files: matched case-insensitively on the extension.
		if strings.HasSuffix(lowerName, ".safetensors") {
			safetensorsPaths = append(safetensorsPaths, fullPath)
		}

		// Config files: *.json and merges.txt (both case-insensitive).
		if strings.HasSuffix(lowerName, ".json") || strings.EqualFold(name, "merges.txt") {
			configFiles = append(configFiles, fullPath)
		}
	}

	if len(safetensorsPaths) == 0 {
		return nil, "", fmt.Errorf("no safetensors files found in directory: %s", dirPath)
	}

	// Sort so the resulting artifact does not depend on directory iteration order.
	sort.Strings(safetensorsPaths)

	if len(configFiles) > 0 {
		// Sorted input gives a stable member order inside the tar archive.
		sort.Strings(configFiles)

		tempConfigArchive, err = CreateTempConfigArchive(configFiles)
		if err != nil {
			return nil, "", fmt.Errorf("create config archive: %w", err)
		}
	}

	return safetensorsPaths, tempConfigArchive, nil
}

// CreateTempConfigArchive creates a temporary tar archive containing the specified config files.
// Each file is stored under its base name only, in the order given.
// It returns the path to the temporary tar file and any error encountered.
// On success the caller is responsible for removing the temporary file when done;
// on failure the temporary file is removed before returning and the path is "".
func CreateTempConfigArchive(configFiles []string) (string, error) {
	tmpFile, err := os.CreateTemp("", "vllm-config-*.tar")
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()

	// Remove the temp file on every failure path; flipped only on full success.
	keepTempFile := false
	defer func() {
		if !keepTempFile {
			// Best-effort cleanup: the error already being returned is the
			// one the caller needs to see.
			_ = os.Remove(tmpPath)
		}
	}()

	tw := tar.NewWriter(tmpFile)

	for _, filePath := range configFiles {
		if err := addFileToTar(tw, filePath); err != nil {
			// The archive is being discarded, so close errors are irrelevant.
			_ = tw.Close()
			_ = tmpFile.Close()
			return "", err
		}
	}

	// Close the tar writer first so the trailer is flushed into the file.
	if err := tw.Close(); err != nil {
		_ = tmpFile.Close()
		return "", fmt.Errorf("close tar writer: %w", err)
	}

	if err := tmpFile.Close(); err != nil {
		return "", fmt.Errorf("close temp file: %w", err)
	}

	keepTempFile = true
	return tmpPath, nil
}

// addFileToTar adds a single regular file to the tar archive with only its
// basename (not full path). Symlinks are followed; anything that does not
// resolve to a regular file is rejected before a header is written.
func addFileToTar(tw *tar.Writer, filePath string) error {
	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("open file %s: %w", filePath, err)
	}
	defer file.Close()

	fileInfo, err := file.Stat()
	if err != nil {
		return fmt.Errorf("stat file %s: %w", filePath, err)
	}

	// A directory (e.g. a symlink to one named "x.json") or other special
	// file cannot be streamed as tar content; fail early with a clear message
	// instead of a read error after the header is already written.
	if !fileInfo.Mode().IsRegular() {
		return fmt.Errorf("add file %s: not a regular file", filePath)
	}

	header := &tar.Header{
		Name: filepath.Base(filePath),
		Size: fileInfo.Size(),
		// Only the permission bits belong in the tar mode field. The raw
		// fs.FileMode uses Go-specific bit positions for setuid/setgid/sticky
		// and must not be written verbatim.
		Mode:    int64(fileInfo.Mode().Perm()),
		ModTime: fileInfo.ModTime(),
	}

	if err := tw.WriteHeader(header); err != nil {
		return fmt.Errorf("write tar header for %s: %w", filePath, err)
	}

	if _, err := io.Copy(tw, file); err != nil {
		return fmt.Errorf("write tar content for %s: %w", filePath, err)
	}

	return nil
}