Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 133 additions & 46 deletions pkg/archiver/archiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,63 +25,115 @@ import (
"strings"
)

// Tar tars the target file and return the content by stream.
func Tar(path string) (io.Reader, error) {
// Tar creates a tar archive of the specified path (file or directory)
// and returns the content as a stream. For individual files, it preserves
// the directory structure relative to the working directory.
func Tar(srcPath string, workDir string) (io.Reader, error) {
pr, pw := io.Pipe()

go func() {
defer pw.Close()
// create the tar writer.
tw := tar.NewWriter(pw)
defer tw.Close()

file, err := os.Open(path)
info, err := os.Stat(srcPath)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to open file: %w", err))
pw.CloseWithError(fmt.Errorf("failed to stat source path: %w", err))
return
}

defer file.Close()
info, err := file.Stat()
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to stat file: %w", err))
return
}
// Handle directories and files differently.
if info.IsDir() {
// For directories, walk through and add all files/subdirs.
err = filepath.Walk(srcPath, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

// Create a relative path for the tar file header.
relPath, err := filepath.Rel(workDir, path)
if err != nil {
return fmt.Errorf("failed to get relative path: %w", err)
}

header, err := tar.FileInfoHeader(info, "")
if err != nil {
return fmt.Errorf("failed to create tar header: %w", err)
}

// Set the header name to preserve directory structure.
header.Name = relPath
if err := tw.WriteHeader(header); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}

if !info.IsDir() {
file, err := os.Open(path)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", path, err)
}
defer file.Close()

if _, err := io.Copy(tw, file); err != nil {
return fmt.Errorf("failed to write file %s to tar: %w", path, err)
}
}

return nil
})

header, err := tar.FileInfoHeader(info, info.Name())
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to create tar file info header: %w", err))
return
}
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to walk directory: %w", err))
return
}
} else {
// For a single file, include the directory structure.
file, err := os.Open(srcPath)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to open file: %w", err))
return
}
defer file.Close()

if err := tw.WriteHeader(header); err != nil {
pw.CloseWithError(fmt.Errorf("failed to write header to tar writer: %w", err))
return
}
header, err := tar.FileInfoHeader(info, "")
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to create tar header: %w", err))
return
}

_, err = io.Copy(tw, file)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to copy file to tar writer: %w", err))
return
// Use relative path as the header name to preserve directory structure
// This keeps the directory structure as part of the file path in the tar.
relPath, err := filepath.Rel(workDir, srcPath)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to get relative path: %w", err))
return
}

// Use the relative path (including directories) as the header name.
header.Name = relPath
if err := tw.WriteHeader(header); err != nil {
pw.CloseWithError(fmt.Errorf("failed to write header: %w", err))
return
}

if _, err := io.Copy(tw, file); err != nil {
pw.CloseWithError(fmt.Errorf("failed to copy file to tar: %w", err))
return
}
}
}()

return pr, nil
}

// Untar untars the target stream to the destination path.
// Untar extracts the contents of a tar archive from the provided reader
// to the specified destination path.
func Untar(reader io.Reader, destPath string) error {
// uncompress gzip if it is a .tar.gz file
// gzipReader, err := gzip.NewReader(reader)
// if err != nil {
// return err
// }
// defer gzipReader.Close()
// tarReader := tar.NewReader(gzipReader)

tarReader := tar.NewReader(reader)

// Ensure destination directory exists.
if err := os.MkdirAll(destPath, 0755); err != nil {
return err
return fmt.Errorf("failed to create destination directory: %w", err)
}

for {
Expand All @@ -90,39 +142,74 @@ func Untar(reader io.Reader, destPath string) error {
break
}
if err != nil {
return err
return fmt.Errorf("error reading tar: %w", err)
}

// sanitize filepaths to prevent directory traversal.
// Sanitize file paths to prevent directory traversal.
cleanPath := filepath.Clean(header.Name)
if strings.Contains(cleanPath, "..") {
if strings.Contains(cleanPath, "..") || strings.HasPrefix(cleanPath, "/") || strings.HasPrefix(cleanPath, ":\\") {
return fmt.Errorf("tar file contains invalid path: %s", cleanPath)
}

path := filepath.Join(destPath, cleanPath)
// check the file type.
targetPath := filepath.Join(destPath, cleanPath)

// Create directories for all path components.
dirPath := filepath.Dir(targetPath)
if err := os.MkdirAll(dirPath, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dirPath, err)
}

switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(path, 0755); err != nil {
return err
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
}

case tar.TypeReg:
file, err := os.Create(path)
file, err := os.OpenFile(
targetPath,
os.O_CREATE|os.O_RDWR|os.O_TRUNC,
os.FileMode(header.Mode),
)
if err != nil {
return err
return fmt.Errorf("failed to create file %s: %w", targetPath, err)
}

if _, err := io.Copy(file, tarReader); err != nil {
file.Close()
return err
return fmt.Errorf("failed to write to file %s: %w", targetPath, err)
}
file.Close()

if err := os.Chmod(path, os.FileMode(header.Mode)); err != nil {
return err
case tar.TypeSymlink:
if isRel(header.Linkname, destPath) && isRel(header.Name, destPath) {
if err := os.Symlink(header.Linkname, targetPath); err != nil {
return fmt.Errorf("failed to create symlink %s -> %s: %w", targetPath, header.Linkname, err)
}
} else {
return fmt.Errorf("symlink %s -> %s points outside of destination directory", targetPath, header.Linkname)
}

default:
// Skip other types.
continue
}
}

return nil
}

// isRel checks if the candidate path is within the target directory after resolving symbolic links.
func isRel(candidate, target string) bool {
if filepath.IsAbs(candidate) {
return false
}

realpath, err := filepath.EvalSymlinks(filepath.Join(target, candidate))
if err != nil {
return false
}

relpath, err := filepath.Rel(target, realpath)
return err == nil && !strings.HasPrefix(filepath.Clean(relpath), "..")
}
95 changes: 95 additions & 0 deletions pkg/archiver/archiver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright 2024 The CNAI Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package archiver

import (
"bytes"
"io"
"os"
"path/filepath"
"testing"
)

func TestTar(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "archiver_test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)

filePath := filepath.Join(tmpDir, "testfile.txt")
if err := os.WriteFile(filePath, []byte("hello"), 0644); err != nil {
t.Fatalf("write file error: %v", err)
}

tarReader, err := Tar(filePath, tmpDir)
if err != nil {
t.Fatalf("Tar error: %v", err)
}

var buf bytes.Buffer
if _, err := io.Copy(&buf, tarReader); err != nil {
t.Fatalf("copy tar error: %v", err)
}

if buf.Len() == 0 {
t.Fatal("tar archive is empty")
}
}

func TestUntar(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "archiver_test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)

filePath := filepath.Join(tmpDir, "testfile.txt")
if err := os.WriteFile(filePath, []byte("hello"), 0644); err != nil {
t.Fatalf("write file error: %v", err)
}

tarReader, err := Tar(filePath, tmpDir)
if err != nil {
t.Fatalf("Tar error: %v", err)
}

var buf bytes.Buffer
if _, err := io.Copy(&buf, tarReader); err != nil {
t.Fatalf("copy tar error: %v", err)
}

extractDir, err := os.MkdirTemp("", "archiver_extracted")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(extractDir)

if err := Untar(bytes.NewReader(buf.Bytes()), extractDir); err != nil {
t.Fatalf("Untar error: %v", err)
}

extractedFile := filepath.Join(extractDir, filepath.Base(filePath))
data, err := os.ReadFile(extractedFile)
if err != nil {
t.Fatalf("read extracted file error: %v", err)
}

if string(data) != "hello" {
t.Errorf("expected 'hello', got '%s'", string(data))
}
}
30 changes: 21 additions & 9 deletions pkg/backend/build/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,37 +21,49 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"

"github.com/CloudNativeAI/modctl/pkg/archiver"
"github.com/CloudNativeAI/modctl/pkg/modelfile"
"github.com/CloudNativeAI/modctl/pkg/storage"
modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"

modelspec "github.com/CloudNativeAI/model-spec/specs-go/v1"
godigest "github.com/opencontainers/go-digest"
spec "github.com/opencontainers/image-spec/specs-go"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
)

// BuildLayer converts the file to the image blob and push it to the storage.
func BuildLayer(ctx context.Context, store storage.Storage, mediaType, workDir, repo, path string) (ocispec.Descriptor, error) {
reader, err := archiver.Tar(path)
info, err := os.Stat(path)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to tar file: %w", err)
return ocispec.Descriptor{}, fmt.Errorf("failed to get file info: %w", err)
}

digest, size, err := store.PushBlob(ctx, repo, reader)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to push blob to storage: %w", err)
if info.IsDir() {
return ocispec.Descriptor{}, fmt.Errorf("%s is a directory and not supported yet", path)
}

absPath, err := filepath.Abs(workDir)
workDirPath, err := filepath.Abs(workDir)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to get absolute path of workDir: %w", err)
}

filePath, err := filepath.Rel(absPath, path)
reader, err := archiver.Tar(path, workDirPath)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to tar file: %w", err)
}

digest, size, err := store.PushBlob(ctx, repo, reader)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to push blob to storage: %w", err)
}

// Gets the relative path of the file as annotation.
//nolint:typecheck
relPath, err := filepath.Rel(workDirPath, path)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to get relative path: %w", err)
}
Expand All @@ -61,7 +73,7 @@ func BuildLayer(ctx context.Context, store storage.Storage, mediaType, workDir,
Digest: godigest.Digest(digest),
Size: size,
Annotations: map[string]string{
modelspec.AnnotationFilepath: filePath,
modelspec.AnnotationFilepath: relPath,
},
}, nil
}
Expand Down
Loading