264 changes: 264 additions & 0 deletions pkg/fileservice/aws_sdk_v2.go
@@ -23,7 +23,9 @@ import (
"iter"
"math"
gotrace "runtime/trace"
"sort"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
@@ -191,6 +193,7 @@ func NewAwsSDKv2(
}

var _ ObjectStorage = new(AwsSDKv2)
var _ ParallelMultipartWriter = new(AwsSDKv2)

func (a *AwsSDKv2) List(
ctx context.Context,
@@ -444,6 +447,267 @@ func (a *AwsSDKv2) Write(
return
}

func (a *AwsSDKv2) SupportsParallelMultipart() bool {
return true
}

func (a *AwsSDKv2) WriteMultipartParallel(
ctx context.Context,
key string,
r io.Reader,
sizeHint *int64,
opt *ParallelMultipartOption,
) (err error) {
defer wrapSizeMismatchErr(&err)

options := normalizeParallelOption(opt)
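// Payloads known to be smaller than the minimum multipart part size
// cannot go through multipart upload; fall back to a single-shot Write.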
if sizeHint != nil && *sizeHint < minMultipartPartSize {
return a.Write(ctx, key, r, sizeHint, options.Expire)
}
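// With a known size, pre-validate against the S3 part-count limit
// before uploading anything.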
if sizeHint != nil {
expectedParts := (*sizeHint + options.PartSize - 1) / options.PartSize
if expectedParts > maxMultipartParts {
return moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", expectedParts)
}
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

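// Pool part-sized buffers so workers reuse memory instead of
// allocating PartSize bytes for every part.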
bufPool := sync.Pool{
New: func() any {
buf := make([]byte, options.PartSize)
return &buf
},
}

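// readChunk fills one pooled buffer with up to PartSize bytes.
// io.EOF with n > 0 marks the final (possibly short) part; io.EOF
// with n == 0 means the reader is exhausted.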
readChunk := func() (bufPtr *[]byte, buf []byte, n int, err error) {
bufPtr = bufPool.Get().(*[]byte)
raw := *bufPtr
n, err = io.ReadFull(r, raw)
switch {
case errors.Is(err, io.EOF):
bufPool.Put(bufPtr)
return nil, nil, 0, io.EOF
case errors.Is(err, io.ErrUnexpectedEOF):
err = io.EOF
return bufPtr, raw, n, err
case err != nil:
bufPool.Put(bufPtr)
return nil, nil, 0, err
default:
return bufPtr, raw, n, nil
}
}

firstBufPtr, firstBuf, firstN, err := readChunk()
if err != nil && !errors.Is(err, io.EOF) {
return err
}
if firstN == 0 && errors.Is(err, io.EOF) {
return nil
}
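// The entire payload fit in a single short chunk; S3 rejects a lone
// part below the minimum size, so write it as a regular object.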
if errors.Is(err, io.EOF) && int64(firstN) < minMultipartPartSize {
data := make([]byte, firstN)
copy(data, firstBuf[:firstN])
bufPool.Put(firstBufPtr)
size := int64(firstN)
return a.Write(ctx, key, bytes.NewReader(data), &size, options.Expire)
}

output, createErr := DoWithRetry("create multipart upload", func() (*s3.CreateMultipartUploadOutput, error) {
return a.client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
Bucket: ptrTo(a.bucket),
Key: ptrTo(key),
Expires: options.Expire,
})
}, maxRetryAttemps, IsRetryableError)
if createErr != nil {
bufPool.Put(firstBufPtr)
return createErr
}

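// On any failure, abort the multipart upload so S3 does not retain
// (and keep billing for) orphaned parts.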
defer func() {
if err != nil {
_, abortErr := a.client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: ptrTo(a.bucket),
Key: ptrTo(key),
UploadId: output.UploadId,
})
err = errors.Join(err, abortErr)
}
}()

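// partJob carries one pooled buffer holding the bytes of a single part.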
type partJob struct {
num int32
buf []byte
bufPtr *[]byte
n int
}

var (
partNum int32
parts []types.CompletedPart
partsLock sync.Mutex
wg sync.WaitGroup
errOnce sync.Once
firstErr error
)

setErr := func(e error) {
if e == nil {
return
}
errOnce.Do(func() {
firstErr = e
cancel()
})
}

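// Buffer up to 2x Concurrency jobs so the reader rarely blocks on
// slow workers.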
jobCh := make(chan partJob, options.Concurrency*2)

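// Each worker drains jobCh, uploads its part with retry, recycles the
// buffer, and records the completed part's ETag.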
startWorker := func() error {
wg.Add(1)
return getParallelUploadPool().Submit(func() {
defer wg.Done()
for job := range jobCh {
if ctx.Err() != nil {
if job.bufPtr != nil {
bufPool.Put(job.bufPtr)
}
continue
}
uploadOutput, uploadErr := DoWithRetry("upload part", func() (*s3.UploadPartOutput, error) {
return a.client.UploadPart(ctx, &s3.UploadPartInput{
Bucket: ptrTo(a.bucket),
Key: ptrTo(key),
PartNumber: &job.num,
UploadId: output.UploadId,
Body: bytes.NewReader(job.buf[:job.n]),
})
}, maxRetryAttemps, IsRetryableError)
if uploadErr != nil {
setErr(uploadErr)
if job.bufPtr != nil {
bufPool.Put(job.bufPtr)
}
continue
}
if job.bufPtr != nil {
bufPool.Put(job.bufPtr)
}
partsLock.Lock()
parts = append(parts, types.CompletedPart{
ETag: uploadOutput.ETag,
PartNumber: ptrTo(job.num),
})
partsLock.Unlock()
}
})
}

for i := 0; i < options.Concurrency; i++ {
if submitErr := startWorker(); submitErr != nil {
setErr(submitErr)
break
}
}

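// sendJob assigns the next part number and hands the buffer to a
// worker, failing fast if the part-count limit is exceeded or the
// context is canceled.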
sendJob := func(bufPtr *[]byte, buf []byte, n int) bool {
partNum++
if partNum > maxMultipartParts {
setErr(moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", partNum))
if bufPtr != nil {
bufPool.Put(bufPtr)
}
return false
}
job := partJob{
num: partNum,
buf: buf,
bufPtr: bufPtr,
n: n,
}
select {
case jobCh <- job:
return true
case <-ctx.Done():
if bufPtr != nil {
bufPool.Put(bufPtr)
}
setErr(ctx.Err())
return false
}
}

if !sendJob(firstBufPtr, firstBuf, firstN) {
close(jobCh)
wg.Wait()
if firstErr != nil {
return firstErr
}
return ctx.Err()
}

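// Stream the remaining chunks to the workers until EOF or error.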
for {
nextBufPtr, nextBuf, nextN, readErr := readChunk()
if errors.Is(readErr, io.EOF) && nextN == 0 {
break
}
if readErr != nil && !errors.Is(readErr, io.EOF) {
setErr(readErr)
if nextBufPtr != nil {
bufPool.Put(nextBufPtr)
}
break
}
if nextN == 0 {
if nextBufPtr != nil {
bufPool.Put(nextBufPtr)
}
break
}
if !sendJob(nextBufPtr, nextBuf, nextN) {
break
}
if readErr != nil && errors.Is(readErr, io.EOF) {
break
}
}

close(jobCh)
wg.Wait()

if firstErr != nil {
err = firstErr
return err
}
if len(parts) == 0 {
return nil
}
if len(parts) != int(partNum) {
return moerr.NewInternalErrorNoCtxf("multipart upload incomplete, expect %d parts got %d", partNum, len(parts))
}

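// Parts complete out of order; CompleteMultipartUpload requires them
// sorted by ascending part number.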
sort.Slice(parts, func(i, j int) bool {
return *parts[i].PartNumber < *parts[j].PartNumber
})

_, err = DoWithRetry("complete multipart upload", func() (*s3.CompleteMultipartUploadOutput, error) {
return a.client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: ptrTo(a.bucket),
Key: ptrTo(key),
UploadId: output.UploadId,
MultipartUpload: &types.CompletedMultipartUpload{Parts: parts},
})
}, maxRetryAttemps, IsRetryableError)
if err != nil {
return err
}

return nil
}

func (a *AwsSDKv2) Read(
ctx context.Context,
key string,
14 changes: 14 additions & 0 deletions pkg/fileservice/get.go
@@ -53,6 +53,16 @@ func Get[T any](fs FileService, name string) (res T, err error) {

var NoDefaultCredentialsForETL = os.Getenv("MO_NO_DEFAULT_CREDENTIALS") != ""

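// etlParallelMode resolves the parallel-upload mode for ETL file services:
// an explicit context value takes precedence, then the MO_ETL_PARALLEL_MODE
// environment variable; otherwise parallel uploads stay off.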
func etlParallelMode(ctx context.Context) ParallelMode {
if mode, ok := parallelModeFromContext(ctx); ok {
return mode
}
if mode, ok := parseParallelMode(strings.TrimSpace(os.Getenv("MO_ETL_PARALLEL_MODE"))); ok {
return mode
}
return ParallelOff
}

// GetForETL gets or creates a FileService instance for ETL operations.
// If the service part of the path is empty, a LocalETLFS will be created.
// If the service part of the path is not empty, an ETLFileService typed instance will be extracted from the fs argument.
@@ -110,6 +120,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
KeySecret: accessSecret,
KeyPrefix: keyPrefix,
Name: name,
ParallelMode: etlParallelMode(ctx),
},
DisabledCacheConfig,
nil,
@@ -143,6 +154,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
Bucket: bucket,
KeyPrefix: keyPrefix,
Name: name,
ParallelMode: etlParallelMode(ctx),
},
DisabledCacheConfig,
nil,
@@ -157,6 +169,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
}
args.NoBucketValidation = true
args.IsHDFS = fsPath.Service == "hdfs"
args.ParallelMode = etlParallelMode(ctx)
res, err = NewS3FS(
ctx,
args,
@@ -198,6 +211,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
KeyPrefix: keyPrefix,
Name: name,
IsMinio: true,
ParallelMode: etlParallelMode(ctx),
},
DisabledCacheConfig,
nil,
54 changes: 54 additions & 0 deletions pkg/fileservice/get_test.go
@@ -19,6 +19,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"iter"
)

func TestGetForBackup(t *testing.T) {
@@ -30,3 +31,56 @@
assert.True(t, ok)
assert.Equal(t, dir, localFS.rootPath)
}

func TestGetForBackupS3Opts(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
spec := JoinPath("s3-opts,endpoint=disk,bucket="+dir+",prefix=backup-prefix,name=backup", "object")
fs, err := GetForBackup(ctx, spec)
assert.Nil(t, err)
s3fs, ok := fs.(*S3FS)
assert.True(t, ok)
assert.Equal(t, "backup", s3fs.name)
assert.Equal(t, "backup-prefix", s3fs.keyPrefix)
}

type dummyFileService struct{ name string }

func (d dummyFileService) Delete(ctx context.Context, filePaths ...string) error { return nil }
func (d dummyFileService) Name() string { return d.name }
func (d dummyFileService) Read(ctx context.Context, vector *IOVector) error { return nil }
func (d dummyFileService) ReadCache(ctx context.Context, vector *IOVector) error { return nil }
func (d dummyFileService) Write(ctx context.Context, vector IOVector) error { return nil }
func (d dummyFileService) List(ctx context.Context, dirPath string) iter.Seq2[*DirEntry, error] {
return func(yield func(*DirEntry, error) bool) {
yield(&DirEntry{Name: "a"}, nil)
}
}
func (d dummyFileService) StatFile(ctx context.Context, filePath string) (*DirEntry, error) {
return &DirEntry{Name: filePath}, nil
}
func (d dummyFileService) PrefetchFile(ctx context.Context, filePath string) error { return nil }
func (d dummyFileService) Cost() *CostAttr { return nil }
func (d dummyFileService) Close(ctx context.Context) {}

func TestGetFromMappings(t *testing.T) {
fs1 := dummyFileService{name: "first"}
fs2 := dummyFileService{name: "second"}
mapping, err := NewFileServices("first", fs1, fs2)
assert.NoError(t, err)

var res FileService
res, err = Get[FileService](mapping, "second")
assert.NoError(t, err)
assert.Equal(t, "second", res.Name())

_, err = Get[FileService](mapping, "missing")
assert.Error(t, err)

res, err = Get[FileService](fs1, "first")
assert.NoError(t, err)
assert.Equal(t, "first", res.Name())

_, err = Get[FileService](fs1, "other")
assert.Error(t, err)
}