
Commit 04fa549

write object in parallel
1 parent ccb29ae commit 04fa549

File tree

11 files changed (+684, −5 lines)

pkg/fileservice/aws_sdk_v2.go

Lines changed: 250 additions & 0 deletions
@@ -23,7 +23,9 @@ import (
 	"iter"
 	"math"
 	gotrace "runtime/trace"
+	"sort"
 	"strings"
+	"sync"
 	"time"

 	"github.com/aws/aws-sdk-go-v2/aws"
@@ -191,6 +193,7 @@ func NewAwsSDKv2(
 }

 var _ ObjectStorage = new(AwsSDKv2)
+var _ ParallelMultipartWriter = new(AwsSDKv2)

 func (a *AwsSDKv2) List(
 	ctx context.Context,
@@ -444,6 +447,253 @@ func (a *AwsSDKv2) Write(
 	return
 }

+func (a *AwsSDKv2) SupportsParallelMultipart() bool {
+	return true
+}
+
+func (a *AwsSDKv2) WriteMultipartParallel(
+	ctx context.Context,
+	key string,
+	r io.Reader,
+	sizeHint *int64,
+	opt *ParallelMultipartOption,
+) (err error) {
+	defer wrapSizeMismatchErr(&err)
+
+	options := normalizeParallelOption(opt)
+	if sizeHint != nil && *sizeHint < minMultipartPartSize {
+		return a.Write(ctx, key, r, sizeHint, options.Expire)
+	}
+	if sizeHint != nil {
+		expectedParts := (*sizeHint + options.PartSize - 1) / options.PartSize
+		if expectedParts > maxMultipartParts {
+			return moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", expectedParts)
+		}
+	}
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	bufPool := sync.Pool{
+		New: func() any {
+			return make([]byte, options.PartSize)
+		},
+	}
+
+	readChunk := func() (buf []byte, n int, err error) {
+		raw := bufPool.Get().([]byte)
+		n, err = io.ReadFull(r, raw)
+		switch {
+		case errors.Is(err, io.EOF):
+			bufPool.Put(raw)
+			return nil, 0, io.EOF
+		case errors.Is(err, io.ErrUnexpectedEOF):
+			err = io.EOF
+			return raw, n, err
+		case err != nil:
+			bufPool.Put(raw)
+			return nil, 0, err
+		default:
+			return raw, n, nil
+		}
+	}
+
+	firstBuf, firstN, err := readChunk()
+	if err != nil && !errors.Is(err, io.EOF) {
+		return err
+	}
+	if firstN == 0 && errors.Is(err, io.EOF) {
+		return nil
+	}
+	if errors.Is(err, io.EOF) && int64(firstN) < minMultipartPartSize {
+		data := make([]byte, firstN)
+		copy(data, firstBuf[:firstN])
+		bufPool.Put(firstBuf)
+		size := int64(firstN)
+		return a.Write(ctx, key, bytes.NewReader(data), &size, options.Expire)
+	}
+
+	output, createErr := DoWithRetry("create multipart upload", func() (*s3.CreateMultipartUploadOutput, error) {
+		return a.client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
+			Bucket:  ptrTo(a.bucket),
+			Key:     ptrTo(key),
+			Expires: options.Expire,
+		})
+	}, maxRetryAttemps, IsRetryableError)
+	if createErr != nil {
+		bufPool.Put(firstBuf)
+		return createErr
+	}
+
+	defer func() {
+		if err != nil {
+			_, abortErr := a.client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
+				Bucket:   ptrTo(a.bucket),
+				Key:      ptrTo(key),
+				UploadId: output.UploadId,
+			})
+			err = errors.Join(err, abortErr)
+		}
+	}()
+
+	type partJob struct {
+		num int32
+		buf []byte
+		n   int
+	}
+
+	var (
+		partNum   int32
+		parts     []types.CompletedPart
+		partsLock sync.Mutex
+		wg        sync.WaitGroup
+		errOnce   sync.Once
+		firstErr  error
+	)
+
+	setErr := func(e error) {
+		if e == nil {
+			return
+		}
+		errOnce.Do(func() {
+			firstErr = e
+			cancel()
+		})
+	}
+
+	jobCh := make(chan partJob, options.Concurrency*2)
+
+	startWorker := func() error {
+		wg.Add(1)
+		return getParallelUploadPool().Submit(func() {
+			defer wg.Done()
+			for job := range jobCh {
+				if ctx.Err() != nil {
+					bufPool.Put(job.buf)
+					continue
+				}
+				uploadOutput, uploadErr := DoWithRetry("upload part", func() (*s3.UploadPartOutput, error) {
+					return a.client.UploadPart(ctx, &s3.UploadPartInput{
+						Bucket:     ptrTo(a.bucket),
+						Key:        ptrTo(key),
+						PartNumber: &job.num,
+						UploadId:   output.UploadId,
+						Body:       bytes.NewReader(job.buf[:job.n]),
+					})
+				}, maxRetryAttemps, IsRetryableError)
+				if uploadErr != nil {
+					setErr(uploadErr)
+					bufPool.Put(job.buf)
+					continue
+				}
+				bufPool.Put(job.buf)
+				partsLock.Lock()
+				parts = append(parts, types.CompletedPart{
+					ETag:       uploadOutput.ETag,
+					PartNumber: ptrTo(job.num),
+				})
+				partsLock.Unlock()
+			}
+		})
+	}
+
+	for i := 0; i < options.Concurrency; i++ {
+		if submitErr := startWorker(); submitErr != nil {
+			setErr(submitErr)
+			break
+		}
+	}
+
+	sendJob := func(buf []byte, n int) bool {
+		partNum++
+		if partNum > maxMultipartParts {
+			setErr(moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", partNum))
+			bufPool.Put(buf)
+			return false
+		}
+		job := partJob{
+			num: partNum,
+			buf: buf,
+			n:   n,
+		}
+		select {
+		case jobCh <- job:
+			return true
+		case <-ctx.Done():
+			bufPool.Put(buf)
+			setErr(ctx.Err())
+			return false
+		}
+	}
+
+	if !sendJob(firstBuf, firstN) {
+		close(jobCh)
+		wg.Wait()
+		if firstErr != nil {
+			return firstErr
+		}
+		return ctx.Err()
+	}
+
+	for {
+		nextBuf, nextN, readErr := readChunk()
+		if errors.Is(readErr, io.EOF) && nextN == 0 {
+			break
+		}
+		if readErr != nil && !errors.Is(readErr, io.EOF) {
+			setErr(readErr)
+			if nextBuf != nil {
+				bufPool.Put(nextBuf)
+			}
+			break
+		}
+		if nextN == 0 {
+			if nextBuf != nil {
+				bufPool.Put(nextBuf)
+			}
+			break
+		}
+		if !sendJob(nextBuf, nextN) {
+			break
+		}
+		if readErr != nil && errors.Is(readErr, io.EOF) {
+			break
+		}
+	}

+	close(jobCh)
+	wg.Wait()
+
+	if firstErr != nil {
+		err = firstErr
+		return err
+	}
+	if len(parts) == 0 {
+		return nil
+	}
+	if len(parts) != int(partNum) {
+		return moerr.NewInternalErrorNoCtxf("multipart upload incomplete, expect %d parts got %d", partNum, len(parts))
+	}
+
+	sort.Slice(parts, func(i, j int) bool {
+		return *parts[i].PartNumber < *parts[j].PartNumber
+	})
+
+	_, err = DoWithRetry("complete multipart upload", func() (*s3.CompleteMultipartUploadOutput, error) {
+		return a.client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
+			Bucket:          ptrTo(a.bucket),
+			Key:             ptrTo(key),
+			UploadId:        output.UploadId,
+			MultipartUpload: &types.CompletedMultipartUpload{Parts: parts},
+		})
+	}, maxRetryAttemps, IsRetryableError)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func (a *AwsSDKv2) Read(
 	ctx context.Context,
 	key string,
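
For context (not part of the commit): a minimal caller-side sketch of the new method, assuming an already-constructed *AwsSDKv2 named storage and a hypothetical local file. Per the code above, a size hint below the 5MB minimum part size falls back to the plain Write path, and a nil option uses the 64MB part size / NumCPU concurrency defaults.

	// Hypothetical usage; `storage` and the file path are assumptions.
	f, err := os.Open("/tmp/big-object.bin")
	if err != nil {
		return err
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		return err
	}
	size := info.Size()

	// 16MB parts uploaded by 8 workers.
	err = storage.WriteMultipartParallel(ctx, "data/big-object.bin", f, &size, &ParallelMultipartOption{
		PartSize:    16 << 20,
		Concurrency: 8,
	})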

pkg/fileservice/file_service.go

Lines changed: 5 additions & 0 deletions
@@ -97,6 +97,11 @@ type IOVector struct {

 	// Caches indicates extra caches to operate on
 	Caches []IOVectorCache
+
+	// DisableParallel controls whether to skip parallel multipart uploads even if supported.
+	DisableParallel bool
+	// ForceParallel controls whether to try parallel multipart uploads when possible.
+	ForceParallel bool
 }

 type IOEntry struct {
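
A short sketch of how a write request might opt in or out through these new flags; the higher-level write path that consults them lives in other files of this commit and is assumed here.

	// Sketch only: FilePath is an existing IOVector field assumed from context;
	// only DisableParallel/ForceParallel come from this diff.
	vec := IOVector{
		FilePath:      "s3:mybucket/data/big-object.bin", // hypothetical path
		ForceParallel: true, // try parallel multipart upload when the backend supports it
	}
	// Conversely, DisableParallel: true skips the parallel path even when supported.
	_ = vec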

pkg/fileservice/object_storage.go

Lines changed: 75 additions & 0 deletions
@@ -18,10 +18,63 @@ import (
 	"context"
 	"io"
 	"iter"
+	"runtime"
+	"sync"
 	"time"
+
+	"github.com/panjf2000/ants/v2"
 )

 const smallObjectThreshold = 64 * (1 << 20)
+const (
+	// defaultParallelMultipartPartSize defines the default per-part size for parallel multipart uploads.
+	defaultParallelMultipartPartSize = 64 * (1 << 20)
+	// minMultipartPartSize is the minimum allowed part size for S3-compatible multipart uploads.
+	minMultipartPartSize = 5 * (1 << 20)
+	// maxMultipartPartSize is the maximum allowed part size for S3-compatible multipart uploads.
+	maxMultipartPartSize = 5 * (1 << 30)
+	// maxMultipartParts is the maximum number of parts allowed for S3-compatible multipart uploads.
+	maxMultipartParts = 10000
+)
+
+var (
+	parallelUploadPoolOnce sync.Once
+	parallelUploadPool     *ants.Pool
+)
+
+func getParallelUploadPool() *ants.Pool {
+	parallelUploadPoolOnce.Do(func() {
+		pool, err := ants.NewPool(runtime.NumCPU())
+		if err != nil {
+			panic(err)
+		}
+		parallelUploadPool = pool
+	})
+	return parallelUploadPool
+}
+
+func normalizeParallelOption(opt *ParallelMultipartOption) ParallelMultipartOption {
+	res := ParallelMultipartOption{}
+	if opt != nil {
+		res = *opt
+	}
+	if res.PartSize <= 0 {
+		res.PartSize = defaultParallelMultipartPartSize
+	}
+	if res.PartSize < minMultipartPartSize {
+		res.PartSize = minMultipartPartSize
+	}
+	if res.PartSize > maxMultipartPartSize {
+		res.PartSize = maxMultipartPartSize
+	}
+	if res.Concurrency <= 0 {
+		res.Concurrency = runtime.NumCPU()
+	}
+	if res.Concurrency < 1 {
+		res.Concurrency = 1
+	}
+	return res
+}

 type ObjectStorage interface {
 	// List lists objects with specified prefix
@@ -78,3 +131,25 @@ type ObjectStorage interface {
 		err error,
 	)
 }
+
+// ParallelMultipartWriter is implemented by storages that support parallel multipart uploads.
+type ParallelMultipartWriter interface {
+	SupportsParallelMultipart() bool
+	WriteMultipartParallel(
+		ctx context.Context,
+		key string,
+		r io.Reader,
+		sizeHint *int64,
+		opt *ParallelMultipartOption,
+	) error
+}
+
+// ParallelMultipartOption controls the part size and parallelism of multipart uploads.
+type ParallelMultipartOption struct {
+	// PartSize configures the size of each part; defaults to 64MB if zero.
+	PartSize int64
+	// Concurrency configures the worker count; defaults to runtime.NumCPU() if zero.
+	Concurrency int
+	// Expire sets the object's expiration time.
+	Expire *time.Time
+}
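
For illustration (not in the diff): a caller can feature-detect the parallel path with a type assertion, and normalizeParallelOption guarantees in-range values; `storage` below is an assumed ObjectStorage value.

	// Hypothetical feature detection against the new interface.
	if pw, ok := storage.(ParallelMultipartWriter); ok && pw.SupportsParallelMultipart() {
		err = pw.WriteMultipartParallel(ctx, key, r, sizeHint, nil) // nil => defaults
	} else {
		err = storage.Write(ctx, key, r, sizeHint, nil)
	}

	// Clamping per the constants above: a 1MB request is raised to the 5MB minimum.
	o := normalizeParallelOption(&ParallelMultipartOption{PartSize: 1 << 20})
	// o.PartSize == minMultipartPartSize, o.Concurrency == runtime.NumCPU()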
