diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index fa1683f30..eb41b6d3c 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -106,6 +106,18 @@ func newPackagedCmd() *cobra.Command { } opts.licensePaths[i] = filepath.Clean(l) } + + // Validate dir-tar paths are relative (not absolute) + for _, dirPath := range opts.dirTarPaths { + if filepath.IsAbs(dirPath) { + return fmt.Errorf( + "dir-tar path must be relative, got absolute path: %s\n\n"+ + "See 'docker model package --help' for more information", + dirPath, + ) + } + } + return nil }, RunE: func(cmd *cobra.Command, args []string) error { @@ -123,6 +135,7 @@ func newPackagedCmd() *cobra.Command { c.Flags().StringVar(&opts.safetensorsDir, "safetensors-dir", "", "absolute path to directory containing safetensors files and config") c.Flags().StringVar(&opts.chatTemplatePath, "chat-template", "", "absolute path to chat template file (must be Jinja format)") c.Flags().StringArrayVarP(&opts.licensePaths, "license", "l", nil, "absolute path to a license file") + c.Flags().StringArrayVar(&opts.dirTarPaths, "dir-tar", nil, "relative path to directory to package as tar (can be specified multiple times)") c.Flags().BoolVar(&opts.push, "push", false, "push to registry (if not set, the model is loaded into the Model Runner content store)") c.Flags().Uint64Var(&opts.contextSize, "context-size", 0, "context size in tokens") return c @@ -134,6 +147,7 @@ type packageOptions struct { ggufPath string safetensorsDir string licensePaths []string + dirTarPaths []string push bool tag string } @@ -213,6 +227,69 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error { } } + // Process directory tar archives + var tempDirTarFiles []string + if len(opts.dirTarPaths) > 0 { + // Schedule cleanup of temp tar files + defer func() { + for _, tempFile := range tempDirTarFiles { + os.Remove(tempFile) + } + }() + + // Determine base directory for resolving relative paths + var baseDir 
string + if opts.safetensorsDir != "" { + baseDir = opts.safetensorsDir + } else { + // For GGUF, use the directory containing the GGUF file + baseDir = filepath.Dir(opts.ggufPath) + } + + for _, relDirPath := range opts.dirTarPaths { + // Reject absolute paths + if filepath.IsAbs(relDirPath) { + return fmt.Errorf("dir-tar path must be relative: %s", relDirPath) + } + + // Resolve the full directory path + fullDirPath := filepath.Join(baseDir, relDirPath) + fullDirPath = filepath.Clean(fullDirPath) + + // Verify the resolved path is within baseDir to prevent directory traversal + relPath, err := filepath.Rel(baseDir, fullDirPath) + if err != nil { + return fmt.Errorf("dir-tar path %q could not be validated: %w", relDirPath, err) + } + // Check if the relative path tries to escape the base directory + if relPath == ".." || len(relPath) >= 3 && relPath[:3] == ".."+string(filepath.Separator) { + return fmt.Errorf("dir-tar path %q escapes base directory", relDirPath) + } + + // Verify the directory exists + info, err := os.Stat(fullDirPath) + if err != nil { + return fmt.Errorf("cannot access directory %q (resolved from %q): %w", fullDirPath, relDirPath, err) + } + if !info.IsDir() { + return fmt.Errorf("path %q is not a directory", fullDirPath) + } + + cmd.PrintErrf("Creating tar archive for directory %q\n", relDirPath) + tempTarPath, err := packaging.CreateDirectoryTarArchive(fullDirPath) + if err != nil { + return fmt.Errorf("create tar archive for directory %q: %w", relDirPath, err) + } + tempDirTarFiles = append(tempDirTarFiles, tempTarPath) + + cmd.PrintErrf("Adding directory tar archive from %q\n", relDirPath) + pkg, err = pkg.WithDirTar(tempTarPath) + if err != nil { + return fmt.Errorf("add directory tar: %w", err) + } + } + } + if opts.push { cmd.PrintErrln("Pushing model to registry...") } else { diff --git a/cmd/cli/docs/reference/docker_model_package.yaml b/cmd/cli/docs/reference/docker_model_package.yaml index bed0cd43c..b30710331 100644 --- 
a/cmd/cli/docs/reference/docker_model_package.yaml +++ b/cmd/cli/docs/reference/docker_model_package.yaml @@ -28,6 +28,17 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: dir-tar + value_type: stringArray + default_value: '[]' + description: | + relative path to directory to package as tar (can be specified multiple times) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false - option: gguf value_type: string description: absolute path to gguf file diff --git a/cmd/cli/docs/reference/model_package.md b/cmd/cli/docs/reference/model_package.md index 262070ac5..44e7a4e32 100644 --- a/cmd/cli/docs/reference/model_package.md +++ b/cmd/cli/docs/reference/model_package.md @@ -11,6 +11,7 @@ When packaging a Safetensors model, --safetensors-dir should point to a director |:--------------------|:--------------|:--------|:---------------------------------------------------------------------------------------| | `--chat-template` | `string` | | absolute path to chat template file (must be Jinja format) | | `--context-size` | `uint64` | `0` | context size in tokens | +| `--dir-tar` | `stringArray` | | relative path to directory to package as tar (can be specified multiple times) | | `--gguf` | `string` | | absolute path to gguf file | | `-l`, `--license` | `stringArray` | | absolute path to a license file | | `--push` | `bool` | | push to registry (if not set, the model is loaded into the Model Runner content store) | diff --git a/pkg/distribution/builder/builder.go b/pkg/distribution/builder/builder.go index 9b6be8c5f..f64dd7ab4 100644 --- a/pkg/distribution/builder/builder.go +++ b/pkg/distribution/builder/builder.go @@ -102,6 +102,18 @@ func (b *Builder) WithConfigArchive(path string) (*Builder, error) { }, nil } +// WithDirTar adds a directory tar archive to the artifact. +// Multiple directory tar archives can be added by calling this method multiple times. 
+func (b *Builder) WithDirTar(path string) (*Builder, error) { + dirTarLayer, err := partial.NewLayer(path, types.MediaTypeDirTar) + if err != nil { + return nil, fmt.Errorf("dir tar layer from %q: %w", path, err) + } + return &Builder{ + model: mutate.AppendLayers(b.model, dirTarLayer), + }, nil +} + // Target represents a build target type Target interface { Write(context.Context, types.ModelArtifact, io.Writer) error diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index 6609063a0..d7b9cdea4 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -60,6 +60,11 @@ func Unpack(dir string, model types.Model) (*Bundle, error) { } } + // Unpack directory tar archives (can be multiple) + if err := unpackDirTarArchives(bundle, model); err != nil { + return nil, fmt.Errorf("unpack directory tar archives: %w", err) + } + // Always create the runtime config if err := unpackRuntimeConfig(bundle, model); err != nil { return nil, fmt.Errorf("add config.json to runtime bundle: %w", err) @@ -230,6 +235,52 @@ func unpackConfigArchive(bundle *Bundle, mdl types.Model) error { return nil } +func unpackDirTarArchives(bundle *Bundle, mdl types.Model) error { + // Cast to ModelArtifact to access Layers() method + artifact, ok := mdl.(types.ModelArtifact) + if !ok { + // If it's not a ModelArtifact, there are no layers to extract + return nil + } + + // Get all layers from the model + layers, err := artifact.Layers() + if err != nil { + return fmt.Errorf("get model layers: %w", err) + } + + modelDir := filepath.Join(bundle.dir, ModelSubdir) + + // Iterate through layers and extract directory tar archives + for _, layer := range layers { + mediaType, err := layer.MediaType() + if err != nil { + fmt.Printf("Warning: failed to get media type for layer: %v", err) + continue + } + + // Check if this is a directory tar layer + if mediaType != types.MediaTypeDirTar { + continue + } + + // 
Get the layer as an uncompressed stream (decompression handled automatically) + uncompressed, err := layer.Uncompressed() + if err != nil { + return fmt.Errorf("get uncompressed layer: %w", err) + } + + // Stream directly to tar extraction - no temp file needed + if err := extractTarArchiveFromReader(uncompressed, modelDir); err != nil { + uncompressed.Close() + return fmt.Errorf("extract directory tar archive: %w", err) + } + uncompressed.Close() + } + + return nil +} + // validatePathWithinDirectory checks if targetPath is within baseDir to prevent directory traversal attacks. // It uses filepath.IsLocal() to provide robust security against // various directory traversal attempts including edge cases like empty paths, ".", "..", symbolic links, etc. @@ -264,14 +315,7 @@ func validatePathWithinDirectory(baseDir, targetPath string) error { return nil } -func extractTarArchive(archivePath, destDir string) error { - // Open the tar file - file, err := os.Open(archivePath) - if err != nil { - return fmt.Errorf("open tar archive: %w", err) - } - defer file.Close() - +func extractTarArchiveFromReader(r io.Reader, destDir string) error { // Get absolute path of destination directory for security checks absDestDir, err := filepath.Abs(destDir) if err != nil { @@ -279,7 +323,7 @@ func extractTarArchive(archivePath, destDir string) error { } // Create tar reader - tr := tar.NewReader(file) + tr := tar.NewReader(r) // Extract files for { @@ -328,6 +372,18 @@ func extractTarArchive(archivePath, destDir string) error { return nil } +func extractTarArchive(archivePath, destDir string) error { + // Open the tar file + file, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("open tar archive: %w", err) + } + defer file.Close() + + // Delegate to the streaming version + return extractTarArchiveFromReader(file, destDir) +} + // extractFile extracts a single file from the tar reader func extractFile(tr io.Reader, target string, mode os.FileMode) error { // Ensure 
// CreateDirectoryTarArchive creates a temporary tar archive containing the specified directory
// with its structure preserved. Entry names are slash-separated and relative to the parent of
// the directory, so every entry is prefixed with the directory's own name. dirPath is normalized
// first, which means a trailing separator, "." or ".." still produce an archive rooted at the
// real directory name.
//
// Only directories and regular files are archived. Symlinks are skipped, as are other
// non-regular entries (named pipes, sockets, device nodes), since opening those can block or
// fail. The archive file itself is skipped if it happens to live inside the walked tree.
//
// It returns the path to the temporary tar file and any error encountered. On error the
// temporary file is removed. The caller is responsible for removing the temporary file when done.
func CreateDirectoryTarArchive(dirPath string) (string, error) {
	// Verify directory exists (on the caller's path, so "" is still rejected).
	info, err := os.Stat(dirPath)
	if err != nil {
		return "", fmt.Errorf("stat directory: %w", err)
	}
	if !info.IsDir() {
		return "", fmt.Errorf("path is not a directory: %s", dirPath)
	}

	// Abs also cleans the path. Without this, "dir/" makes filepath.Dir return
	// the directory itself and every entry loses its top-level name.
	absDir, err := filepath.Abs(dirPath)
	if err != nil {
		return "", fmt.Errorf("resolve directory path: %w", err)
	}
	parentDir := filepath.Dir(absDir)

	tmpFile, err := os.CreateTemp("", "dir-tar-*.tar")
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()

	// Remove the temp file on every path except a fully successful return.
	keepTempFile := false
	defer func() {
		if !keepTempFile {
			os.Remove(tmpPath)
		}
	}()

	// Used to recognize the archive itself if the walk reaches the temp directory.
	tmpInfo, err := tmpFile.Stat()
	if err != nil {
		tmpFile.Close()
		return "", fmt.Errorf("stat temp file: %w", err)
	}

	tw := tar.NewWriter(tmpFile)

	err = filepath.Walk(absDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info == nil {
			return fmt.Errorf("nil FileInfo for path: %s", path)
		}

		mode := info.Mode()
		// Skip symlinks and every other non-regular entry: they are not needed
		// for model distribution, and opening a FIFO or socket can hang or fail.
		if !mode.IsDir() && !mode.IsRegular() {
			return nil
		}
		// Never archive the archive into itself.
		if mode.IsRegular() && os.SameFile(info, tmpInfo) {
			return nil
		}

		header, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return fmt.Errorf("create tar header for %s: %w", path, err)
		}

		// Name entries relative to the parent so the directory name is the archive root.
		relPath, err := filepath.Rel(parentDir, path)
		if err != nil {
			return fmt.Errorf("compute relative path: %w", err)
		}
		// Tar entry names always use forward slashes.
		header.Name = filepath.ToSlash(relPath)

		if err := tw.WriteHeader(header); err != nil {
			return fmt.Errorf("write tar header: %w", err)
		}
		if mode.IsDir() {
			return nil
		}
		return copyFileIntoTar(tw, path)
	})
	if err != nil {
		tw.Close()
		tmpFile.Close()
		return "", fmt.Errorf("walk directory: %w", err)
	}

	// Closing the tar writer flushes padding and the end-of-archive trailer.
	if err := tw.Close(); err != nil {
		tmpFile.Close()
		return "", fmt.Errorf("close tar writer: %w", err)
	}
	if err := tmpFile.Close(); err != nil {
		return "", fmt.Errorf("close temp file: %w", err)
	}

	keepTempFile = true
	return tmpPath, nil
}

// copyFileIntoTar streams the contents of the regular file at path into tw as
// the body of the most recently written header.
func copyFileIntoTar(tw *tar.Writer, path string) error {
	file, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("open file %s: %w", path, err)
	}
	// Read-only handle: a Close failure cannot lose data, so it is not reported.
	defer file.Close()

	if _, err := io.Copy(tw, file); err != nil {
		return fmt.Errorf("write tar content for %s: %w", path, err)
	}
	return nil
}
create test directory: %v", err) + } + + // Create test files + testFiles := map[string]string{ + "file1.txt": "content1", + "subdir/file2.txt": "content2", + } + + for relPath, content := range testFiles { + fullPath := filepath.Join(testDir, relPath) + if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write test file %s: %v", relPath, err) + } + } + + // Create tar archive + tarPath, err := CreateDirectoryTarArchive(testDir) + if err != nil { + t.Fatalf("CreateDirectoryTarArchive failed: %v", err) + } + defer os.Remove(tarPath) + + // Verify tar archive exists + if _, err := os.Stat(tarPath); os.IsNotExist(err) { + t.Fatal("Tar archive was not created") + } + + // Read and verify tar contents + file, err := os.Open(tarPath) + if err != nil { + t.Fatalf("Failed to open tar archive: %v", err) + } + defer file.Close() + + tr := tar.NewReader(file) + foundFiles := make(map[string]bool) + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Failed to read tar header: %v", err) + } + + foundFiles[header.Name] = true + + // Verify it's within the test_directory structure + if header.Typeflag == tar.TypeReg { + // Read file content + content, err := io.ReadAll(tr) + if err != nil { + t.Fatalf("Failed to read file content: %v", err) + } + + // Verify content matches + expectedPath := header.Name[len("test_directory/"):] + if expectedContent, ok := testFiles[expectedPath]; ok { + if string(content) != expectedContent { + t.Errorf("File %s content mismatch: got %q, want %q", expectedPath, string(content), expectedContent) + } + } + } + } + + // Verify all expected entries are present + expectedEntries := []string{ + "test_directory", + "test_directory/file1.txt", + "test_directory/subdir", + "test_directory/subdir/file2.txt", + } + + for _, entry := range expectedEntries { + if !foundFiles[entry] { + t.Errorf("Expected entry %q not found in tar archive", entry) + } + } +} + +func 
TestCreateDirTarArchive_NonExistentDir(t *testing.T) { + _, err := CreateDirectoryTarArchive("/nonexistent/directory") + if err == nil { + t.Error("Expected error for non-existent directory, got nil") + } +} + +func TestCreateDirTarArchive_NotADirectory(t *testing.T) { + // Create a temporary file + tempFile, err := os.CreateTemp("", "not-a-dir-*") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tempFile.Name()) + tempFile.Close() + + _, err = CreateDirectoryTarArchive(tempFile.Name()) + if err == nil { + t.Error("Expected error for file path instead of directory, got nil") + } +} diff --git a/pkg/distribution/types/config.go b/pkg/distribution/types/config.go index adce5d9f8..a5bc070dd 100644 --- a/pkg/distribution/types/config.go +++ b/pkg/distribution/types/config.go @@ -23,6 +23,9 @@ const ( // MediaTypeVLLMConfigArchive indicates a tar archive containing vLLM-specific config files. MediaTypeVLLMConfigArchive = types.MediaType("application/vnd.docker.ai.vllm.config.tar") + // MediaTypeDirTar indicates a tar archive containing a directory with its structure preserved. + MediaTypeDirTar = types.MediaType("application/vnd.docker.ai.dir.tar") + // MediaTypeLicense indicates a plain text file containing a license MediaTypeLicense = types.MediaType("application/vnd.docker.ai.license")