diff --git a/pkg/fileservice/aws_sdk_v2.go b/pkg/fileservice/aws_sdk_v2.go
index 413a794c0f47a..189f337e95dfc 100644
--- a/pkg/fileservice/aws_sdk_v2.go
+++ b/pkg/fileservice/aws_sdk_v2.go
@@ -23,7 +23,9 @@ import (
"iter"
"math"
gotrace "runtime/trace"
+ "sort"
"strings"
+ "sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
@@ -191,6 +193,7 @@ func NewAwsSDKv2(
}
var _ ObjectStorage = new(AwsSDKv2)
+var _ ParallelMultipartWriter = new(AwsSDKv2)
func (a *AwsSDKv2) List(
ctx context.Context,
@@ -444,6 +447,267 @@ func (a *AwsSDKv2) Write(
return
}
+func (a *AwsSDKv2) SupportsParallelMultipart() bool {
+ return true
+}
+
+func (a *AwsSDKv2) WriteMultipartParallel(
+ ctx context.Context,
+ key string,
+ r io.Reader,
+ sizeHint *int64,
+ opt *ParallelMultipartOption,
+) (err error) {
+ defer wrapSizeMismatchErr(&err)
+
+ options := normalizeParallelOption(opt)
+ if sizeHint != nil && *sizeHint < minMultipartPartSize {
+ return a.Write(ctx, key, r, sizeHint, options.Expire)
+ }
+ if sizeHint != nil {
+ expectedParts := (*sizeHint + options.PartSize - 1) / options.PartSize
+ if expectedParts > maxMultipartParts {
+ return moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", expectedParts)
+ }
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ bufPool := sync.Pool{
+ New: func() any {
+ buf := make([]byte, options.PartSize)
+ return &buf
+ },
+ }
+
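+ // readChunk fills one part-sized buffer from r; a short final read is
+ // reported as io.EOF together with the bytes that were read, and the buffer
+ // is returned to the pool on a clean EOF or a hard read error.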
+ readChunk := func() (bufPtr *[]byte, buf []byte, n int, err error) {
+ bufPtr = bufPool.Get().(*[]byte)
+ raw := *bufPtr
+ n, err = io.ReadFull(r, raw)
+ switch {
+ case errors.Is(err, io.EOF):
+ bufPool.Put(bufPtr)
+ return nil, nil, 0, io.EOF
+ case errors.Is(err, io.ErrUnexpectedEOF):
+ err = io.EOF
+ return bufPtr, raw, n, err
+ case err != nil:
+ bufPool.Put(bufPtr)
+ return nil, nil, 0, err
+ default:
+ return bufPtr, raw, n, nil
+ }
+ }
+
+ firstBufPtr, firstBuf, firstN, err := readChunk()
+ if err != nil && !errors.Is(err, io.EOF) {
+ return err
+ }
+ if firstN == 0 && errors.Is(err, io.EOF) {
+ return nil
+ }
+ if errors.Is(err, io.EOF) && int64(firstN) < minMultipartPartSize {
+ data := make([]byte, firstN)
+ copy(data, firstBuf[:firstN])
+ bufPool.Put(firstBufPtr)
+ size := int64(firstN)
+ return a.Write(ctx, key, bytes.NewReader(data), &size, options.Expire)
+ }
+
+ output, createErr := DoWithRetry("create multipart upload", func() (*s3.CreateMultipartUploadOutput, error) {
+ return a.client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
+ Bucket: ptrTo(a.bucket),
+ Key: ptrTo(key),
+ Expires: options.Expire,
+ })
+ }, maxRetryAttemps, IsRetryableError)
+ if createErr != nil {
+ bufPool.Put(firstBufPtr)
+ return createErr
+ }
+
+ defer func() {
+ if err != nil {
+ // the upload context may already be canceled when a part fails, so
+ // detach the abort request from that cancellation
+ abortCtx := context.WithoutCancel(ctx)
+ _, abortErr := a.client.AbortMultipartUpload(abortCtx, &s3.AbortMultipartUploadInput{
+ Bucket: ptrTo(a.bucket),
+ Key: ptrTo(key),
+ UploadId: output.UploadId,
+ })
+ err = errors.Join(err, abortErr)
+ }
+ }()
+
+ type partJob struct {
+ num int32
+ buf []byte
+ bufPtr *[]byte
+ n int
+ }
+
+ var (
+ partNum int32
+ parts []types.CompletedPart
+ partsLock sync.Mutex
+ wg sync.WaitGroup
+ errOnce sync.Once
+ firstErr error
+ )
+
+ setErr := func(e error) {
+ if e == nil {
+ return
+ }
+ errOnce.Do(func() {
+ firstErr = e
+ cancel()
+ })
+ }
+
+ jobCh := make(chan partJob, options.Concurrency*2)
+
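+ // each worker drains jobCh: it uploads one part per job (with retries),
+ // records the completed part, and recycles the buffer; on failure it records
+ // the first error and cancels the remaining work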
+ startWorker := func() error {
+ wg.Add(1)
+ submitErr := getParallelUploadPool().Submit(func() {
+ defer wg.Done()
+ for job := range jobCh {
+ if ctx.Err() != nil {
+ if job.bufPtr != nil {
+ bufPool.Put(job.bufPtr)
+ }
+ continue
+ }
+ uploadOutput, uploadErr := DoWithRetry("upload part", func() (*s3.UploadPartOutput, error) {
+ return a.client.UploadPart(ctx, &s3.UploadPartInput{
+ Bucket: ptrTo(a.bucket),
+ Key: ptrTo(key),
+ PartNumber: &job.num,
+ UploadId: output.UploadId,
+ Body: bytes.NewReader(job.buf[:job.n]),
+ })
+ }, maxRetryAttemps, IsRetryableError)
+ if uploadErr != nil {
+ setErr(uploadErr)
+ if job.bufPtr != nil {
+ bufPool.Put(job.bufPtr)
+ }
+ continue
+ }
+ if job.bufPtr != nil {
+ bufPool.Put(job.bufPtr)
+ }
+ partsLock.Lock()
+ parts = append(parts, types.CompletedPart{
+ ETag: uploadOutput.ETag,
+ PartNumber: ptrTo(job.num),
+ })
+ partsLock.Unlock()
+ }
+ })
+ if submitErr != nil {
+ // balance the Add above so wg.Wait cannot block on a worker that
+ // never started
+ wg.Done()
+ }
+ return submitErr
+ }
+
+ for i := 0; i < options.Concurrency; i++ {
+ if submitErr := startWorker(); submitErr != nil {
+ setErr(submitErr)
+ break
+ }
+ }
+
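+ // sendJob assigns the next part number and queues one buffer for upload; it
+ // returns false when the part limit is exceeded or the context is canceled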
+ sendJob := func(bufPtr *[]byte, buf []byte, n int) bool {
+ partNum++
+ if partNum > maxMultipartParts {
+ setErr(moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", partNum))
+ if bufPtr != nil {
+ bufPool.Put(bufPtr)
+ }
+ return false
+ }
+ job := partJob{
+ num: partNum,
+ buf: buf,
+ bufPtr: bufPtr,
+ n: n,
+ }
+ select {
+ case jobCh <- job:
+ return true
+ case <-ctx.Done():
+ if bufPtr != nil {
+ bufPool.Put(bufPtr)
+ }
+ setErr(ctx.Err())
+ return false
+ }
+ }
+
+ if !sendJob(firstBufPtr, firstBuf, firstN) {
+ close(jobCh)
+ wg.Wait()
+ if firstErr != nil {
+ return firstErr
+ }
+ return ctx.Err()
+ }
+
+ for {
+ nextBufPtr, nextBuf, nextN, readErr := readChunk()
+ if errors.Is(readErr, io.EOF) && nextN == 0 {
+ break
+ }
+ if readErr != nil && !errors.Is(readErr, io.EOF) {
+ setErr(readErr)
+ if nextBufPtr != nil {
+ bufPool.Put(nextBufPtr)
+ }
+ break
+ }
+ if nextN == 0 {
+ if nextBufPtr != nil {
+ bufPool.Put(nextBufPtr)
+ }
+ break
+ }
+ if !sendJob(nextBufPtr, nextBuf, nextN) {
+ break
+ }
+ if readErr != nil && errors.Is(readErr, io.EOF) {
+ break
+ }
+ }
+
+ close(jobCh)
+ wg.Wait()
+
+ if firstErr != nil {
+ err = firstErr
+ return err
+ }
+ // at least one part was queued by this point, so any mismatch (including
+ // zero completed parts) means the upload did not finish; returning an error
+ // here also triggers the deferred abort
+ if len(parts) != int(partNum) {
+ return moerr.NewInternalErrorNoCtxf("multipart upload incomplete, expect %d parts got %d", partNum, len(parts))
+ }
+
+ sort.Slice(parts, func(i, j int) bool {
+ return *parts[i].PartNumber < *parts[j].PartNumber
+ })
+
+ _, err = DoWithRetry("complete multipart upload", func() (*s3.CompleteMultipartUploadOutput, error) {
+ return a.client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
+ Bucket: ptrTo(a.bucket),
+ Key: ptrTo(key),
+ UploadId: output.UploadId,
+ MultipartUpload: &types.CompletedMultipartUpload{Parts: parts},
+ })
+ }, maxRetryAttemps, IsRetryableError)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (a *AwsSDKv2) Read(
ctx context.Context,
key string,
diff --git a/pkg/fileservice/get.go b/pkg/fileservice/get.go
index 1e5e6740c9dba..6a9d9ac384236 100644
--- a/pkg/fileservice/get.go
+++ b/pkg/fileservice/get.go
@@ -53,6 +53,16 @@ func Get[T any](fs FileService, name string) (res T, err error) {
var NoDefaultCredentialsForETL = os.Getenv("MO_NO_DEFAULT_CREDENTIALS") != ""
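+// etlParallelMode resolves the parallel upload mode for ETL file services:
+// a per-call override set via WithParallelMode takes precedence, then the
+// MO_ETL_PARALLEL_MODE environment variable, and ParallelOff otherwise.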
+func etlParallelMode(ctx context.Context) ParallelMode {
+ if mode, ok := parallelModeFromContext(ctx); ok {
+ return mode
+ }
+ if mode, ok := parseParallelMode(strings.TrimSpace(os.Getenv("MO_ETL_PARALLEL_MODE"))); ok {
+ return mode
+ }
+ return ParallelOff
+}
+
// GetForETL get or creates a FileService instance for ETL operations
// if service part of path is empty, a LocalETLFS will be created
// if service part of path is not empty, a ETLFileService typed instance will be extracted from fs argument
@@ -110,6 +120,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
KeySecret: accessSecret,
KeyPrefix: keyPrefix,
Name: name,
+ ParallelMode: etlParallelMode(ctx),
},
DisabledCacheConfig,
nil,
@@ -143,6 +154,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
Bucket: bucket,
KeyPrefix: keyPrefix,
Name: name,
+ ParallelMode: etlParallelMode(ctx),
},
DisabledCacheConfig,
nil,
@@ -157,6 +169,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
}
args.NoBucketValidation = true
args.IsHDFS = fsPath.Service == "hdfs"
+ args.ParallelMode = etlParallelMode(ctx)
res, err = NewS3FS(
ctx,
args,
@@ -198,6 +211,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer
KeyPrefix: keyPrefix,
Name: name,
IsMinio: true,
+ ParallelMode: etlParallelMode(ctx),
},
DisabledCacheConfig,
nil,
diff --git a/pkg/fileservice/get_test.go b/pkg/fileservice/get_test.go
index 17a69b9df3d8e..125dc31be0bc1 100644
--- a/pkg/fileservice/get_test.go
+++ b/pkg/fileservice/get_test.go
@@ -19,6 +19,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
+ "iter"
)
func TestGetForBackup(t *testing.T) {
@@ -30,3 +31,56 @@ func TestGetForBackup(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, dir, localFS.rootPath)
}
+
+func TestGetForBackupS3Opts(t *testing.T) {
+ ctx := context.Background()
+ dir := t.TempDir()
+ spec := JoinPath("s3-opts,endpoint=disk,bucket="+dir+",prefix=backup-prefix,name=backup", "object")
+ fs, err := GetForBackup(ctx, spec)
+ assert.Nil(t, err)
+ s3fs, ok := fs.(*S3FS)
+ assert.True(t, ok)
+ assert.Equal(t, "backup", s3fs.name)
+ assert.Equal(t, "backup-prefix", s3fs.keyPrefix)
+}
+
+type dummyFileService struct{ name string }
+
+func (d dummyFileService) Delete(ctx context.Context, filePaths ...string) error { return nil }
+func (d dummyFileService) Name() string { return d.name }
+func (d dummyFileService) Read(ctx context.Context, vector *IOVector) error { return nil }
+func (d dummyFileService) ReadCache(ctx context.Context, vector *IOVector) error { return nil }
+func (d dummyFileService) Write(ctx context.Context, vector IOVector) error { return nil }
+func (d dummyFileService) List(ctx context.Context, dirPath string) iter.Seq2[*DirEntry, error] {
+ return func(yield func(*DirEntry, error) bool) {
+ yield(&DirEntry{Name: "a"}, nil)
+ }
+}
+func (d dummyFileService) StatFile(ctx context.Context, filePath string) (*DirEntry, error) {
+ return &DirEntry{Name: filePath}, nil
+}
+func (d dummyFileService) PrefetchFile(ctx context.Context, filePath string) error { return nil }
+func (d dummyFileService) Cost() *CostAttr { return nil }
+func (d dummyFileService) Close(ctx context.Context) {}
+
+func TestGetFromMappings(t *testing.T) {
+ fs1 := dummyFileService{name: "first"}
+ fs2 := dummyFileService{name: "second"}
+ mapping, err := NewFileServices("first", fs1, fs2)
+ assert.NoError(t, err)
+
+ var res FileService
+ res, err = Get[FileService](mapping, "second")
+ assert.NoError(t, err)
+ assert.Equal(t, "second", res.Name())
+
+ _, err = Get[FileService](mapping, "missing")
+ assert.Error(t, err)
+
+ res, err = Get[FileService](fs1, "first")
+ assert.NoError(t, err)
+ assert.Equal(t, "first", res.Name())
+
+ _, err = Get[FileService](fs1, "other")
+ assert.Error(t, err)
+}
diff --git a/pkg/fileservice/object_storage.go b/pkg/fileservice/object_storage.go
index 426da383721c9..2e61fb48ab15c 100644
--- a/pkg/fileservice/object_storage.go
+++ b/pkg/fileservice/object_storage.go
@@ -18,10 +18,63 @@ import (
"context"
"io"
"iter"
+ "runtime"
+ "sync"
"time"
+
+ "github.com/panjf2000/ants/v2"
)
const smallObjectThreshold = 64 * (1 << 20)
+const (
+ // defaultParallelMultipartPartSize defines the default per-part size for parallel multipart uploads.
+ defaultParallelMultipartPartSize = 64 * (1 << 20)
+ // minMultipartPartSize is the minimum allowed part size for S3-compatible multipart uploads.
+ minMultipartPartSize = 5 * (1 << 20)
+ // maxMultipartPartSize is the maximum allowed part size for S3-compatible multipart uploads.
+ maxMultipartPartSize = 5 * (1 << 30)
+ // maxMultipartParts is the maximum allowed number of parts for S3-compatible multipart uploads.
+ maxMultipartParts = 10000
+)
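+// For example, with the defaults above a 1 GiB object splits into
+// ceil(1 GiB / 64 MiB) = 16 parts, and the largest object a single upload can
+// cover at the default part size is 10000 * 64 MiB = 625 GiB.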
+
+var (
+ parallelUploadPoolOnce sync.Once
+ parallelUploadPool *ants.Pool
+)
+
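+// getParallelUploadPool lazily creates a process-wide singleton pool sized to
+// runtime.NumCPU(); it bounds the total number of concurrently running
+// part-upload workers across all callers.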
+func getParallelUploadPool() *ants.Pool {
+ parallelUploadPoolOnce.Do(func() {
+ pool, err := ants.NewPool(runtime.NumCPU())
+ if err != nil {
+ panic(err)
+ }
+ parallelUploadPool = pool
+ })
+ return parallelUploadPool
+}
+
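+// normalizeParallelOption fills in defaults for a nil or partially populated
+// option: PartSize falls back to defaultParallelMultipartPartSize and is
+// clamped to [minMultipartPartSize, maxMultipartPartSize], and Concurrency
+// falls back to runtime.NumCPU().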
+func normalizeParallelOption(opt *ParallelMultipartOption) ParallelMultipartOption {
+ res := ParallelMultipartOption{}
+ if opt != nil {
+ res = *opt
+ }
+ if res.PartSize <= 0 {
+ res.PartSize = defaultParallelMultipartPartSize
+ }
+ if res.PartSize < minMultipartPartSize {
+ res.PartSize = minMultipartPartSize
+ }
+ if res.PartSize > maxMultipartPartSize {
+ res.PartSize = maxMultipartPartSize
+ }
+ if res.Concurrency <= 0 {
+ res.Concurrency = runtime.NumCPU()
+ }
+ if res.Concurrency < 1 {
+ res.Concurrency = 1
+ }
+ return res
+}
type ObjectStorage interface {
// List lists objects with specified prefix
@@ -78,3 +131,25 @@ type ObjectStorage interface {
err error,
)
}
+
+// ParallelMultipartWriter is implemented by storages that support parallel multipart uploads.
+type ParallelMultipartWriter interface {
+ SupportsParallelMultipart() bool
+ WriteMultipartParallel(
+ ctx context.Context,
+ key string,
+ r io.Reader,
+ sizeHint *int64,
+ opt *ParallelMultipartOption,
+ ) error
+}
+
+// ParallelMultipartOption controls part size and parallelism of multipart uploads.
+type ParallelMultipartOption struct {
+ // PartSize configures the size of each uploaded part; it defaults to 64 MiB when zero and is clamped to the S3 multipart part-size limits.
+ PartSize int64
+ // Concurrency configures worker count; defaults to runtime.NumCPU() if zero.
+ Concurrency int
+ // Expire sets object expiration.
+ Expire *time.Time
+}
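+
+// A minimal usage sketch (illustrative only; the storage value, key, reader,
+// and size below are assumptions, and error handling is elided):
+//
+//	if w, ok := storage.(ParallelMultipartWriter); ok && w.SupportsParallelMultipart() {
+//		err = w.WriteMultipartParallel(ctx, key, reader, &size, &ParallelMultipartOption{
+//			PartSize:    64 << 20,
+//			Concurrency: 4,
+//		})
+//	} else {
+//		err = storage.Write(ctx, key, reader, &size, nil)
+//	}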
diff --git a/pkg/fileservice/object_storage_arguments.go b/pkg/fileservice/object_storage_arguments.go
index 5d178a65796f8..f6fde51332aa0 100644
--- a/pkg/fileservice/object_storage_arguments.go
+++ b/pkg/fileservice/object_storage_arguments.go
@@ -28,13 +28,14 @@ import (
type ObjectStorageArguments struct {
// misc
- Name string `toml:"name"`
- KeyPrefix string `toml:"key-prefix"`
- SharedConfigProfile string `toml:"shared-config-profile"`
- NoDefaultCredentials bool `toml:"no-default-credentials"`
- NoBucketValidation bool `toml:"no-bucket-validation"`
- Concurrency int64 `toml:"concurrency"`
- MaxConnsPerHost int `toml:"max-conns-per-host"`
+ Name string `toml:"name"`
+ KeyPrefix string `toml:"key-prefix"`
+ SharedConfigProfile string `toml:"shared-config-profile"`
+ NoDefaultCredentials bool `toml:"no-default-credentials"`
+ NoBucketValidation bool `toml:"no-bucket-validation"`
+ Concurrency int64 `toml:"concurrency"`
+ MaxConnsPerHost int `toml:"max-conns-per-host"`
+ ParallelMode ParallelMode `toml:"parallel-mode"`
// s3
Bucket string `toml:"bucket"`
@@ -107,6 +108,10 @@ func (o *ObjectStorageArguments) SetFromString(arguments []string) error {
if err == nil {
o.MaxConnsPerHost = n
}
+ case "parallel-mode", "parallel":
+ if mode, ok := parseParallelMode(value); ok {
+ o.ParallelMode = mode
+ }
case "bucket":
o.Bucket = value
diff --git a/pkg/fileservice/object_storage_arguments_test.go b/pkg/fileservice/object_storage_arguments_test.go
index c8e6eb525e903..d87eee3475065 100644
--- a/pkg/fileservice/object_storage_arguments_test.go
+++ b/pkg/fileservice/object_storage_arguments_test.go
@@ -185,6 +185,61 @@ func TestAWSRegion(t *testing.T) {
assert.NotNil(t, args.validate())
}
+func TestSetFromStringParallelMode(t *testing.T) {
+ var args ObjectStorageArguments
+ assert.NoError(t, args.SetFromString([]string{"parallel-mode=force"}))
+ assert.Equal(t, ParallelForce, args.ParallelMode)
+
+ args = ObjectStorageArguments{
+ ParallelMode: ParallelAuto,
+ }
+ assert.NoError(t, args.SetFromString([]string{"parallel-mode=unknown"}))
+ assert.Equal(t, ParallelAuto, args.ParallelMode)
+}
+
+func TestObjectStorageArgumentsValidateDefaults(t *testing.T) {
+ args := ObjectStorageArguments{
+ Endpoint: "example.com",
+ }
+ assert.NoError(t, args.validate())
+ assert.Equal(t, "https://example.com", args.Endpoint)
+ assert.Equal(t, "mo-service", args.RoleSessionName)
+}
+
+func TestObjectStorageArgumentsShouldLoadDefaultCredentials(t *testing.T) {
+ t.Setenv("AWS_ACCESS_KEY_ID", "ak")
+ t.Setenv("AWS_SECRET_ACCESS_KEY", "sk")
+ args := ObjectStorageArguments{}
+ assert.True(t, args.shouldLoadDefaultCredentials())
+
+ args = ObjectStorageArguments{
+ NoDefaultCredentials: true,
+ KeyID: "id",
+ KeySecret: "secret",
+ }
+ assert.False(t, args.shouldLoadDefaultCredentials())
+
+ args = ObjectStorageArguments{
+ NoDefaultCredentials: true,
+ RoleARN: "arn",
+ }
+ assert.True(t, args.shouldLoadDefaultCredentials())
+}
+
+func TestObjectStorageArgumentsString(t *testing.T) {
+ args := ObjectStorageArguments{
+ Name: "foo",
+ KeyPrefix: "bar",
+ Concurrency: 3,
+ }
+ s := args.String()
+ var decoded ObjectStorageArguments
+ assert.NoError(t, json.Unmarshal([]byte(s), &decoded))
+ assert.Equal(t, args.Name, decoded.Name)
+ assert.Equal(t, args.KeyPrefix, decoded.KeyPrefix)
+ assert.Equal(t, args.Concurrency, decoded.Concurrency)
+}
+
func TestParseHDFSArgs(t *testing.T) {
var args ObjectStorageArguments
if err := args.SetFromString([]string{
diff --git a/pkg/fileservice/object_storage_http_trace.go b/pkg/fileservice/object_storage_http_trace.go
index 7a7602a1618ad..7c9350957d9b7 100644
--- a/pkg/fileservice/object_storage_http_trace.go
+++ b/pkg/fileservice/object_storage_http_trace.go
@@ -21,6 +21,7 @@ import (
"net/http/httptrace"
"time"
+ "github.com/matrixorigin/matrixone/pkg/common/moerr"
"github.com/matrixorigin/matrixone/pkg/common/reuse"
)
@@ -35,6 +36,7 @@ func newObjectStorageHTTPTrace(upstream ObjectStorage) *objectStorageHTTPTrace {
}
var _ ObjectStorage = new(objectStorageHTTPTrace)
+var _ ParallelMultipartWriter = new(objectStorageHTTPTrace)
func (o *objectStorageHTTPTrace) Delete(ctx context.Context, keys ...string) (err error) {
traceInfo := o.newTraceInfo()
@@ -78,6 +80,23 @@ func (o *objectStorageHTTPTrace) Write(ctx context.Context, key string, r io.Rea
return o.upstream.Write(ctx, key, r, sizeHint, expire)
}
+func (o *objectStorageHTTPTrace) SupportsParallelMultipart() bool {
+ if p, ok := o.upstream.(ParallelMultipartWriter); ok {
+ return p.SupportsParallelMultipart()
+ }
+ return false
+}
+
+func (o *objectStorageHTTPTrace) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) (err error) {
+ traceInfo := o.newTraceInfo()
+ defer o.closeTraceInfo(traceInfo)
+ ctx = httptrace.WithClientTrace(ctx, traceInfo.trace)
+ if p, ok := o.upstream.(ParallelMultipartWriter); ok {
+ return p.WriteMultipartParallel(ctx, key, r, sizeHint, opt)
+ }
+ return moerr.NewNotSupportedNoCtx("parallel multipart upload")
+}
+
func (o *objectStorageHTTPTrace) newTraceInfo() *traceInfo {
return reuse.Alloc[traceInfo](nil)
}
diff --git a/pkg/fileservice/object_storage_http_trace_test.go b/pkg/fileservice/object_storage_http_trace_test.go
new file mode 100644
index 0000000000000..789bb6f83a311
--- /dev/null
+++ b/pkg/fileservice/object_storage_http_trace_test.go
@@ -0,0 +1,78 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "context"
+ "net/http/httptrace"
+ "strings"
+ "testing"
+
+ "github.com/matrixorigin/matrixone/pkg/common/moerr"
+ "github.com/stretchr/testify/require"
+)
+
+func TestObjectStorageHTTPTraceWriteMultipartParallel(t *testing.T) {
+ upstream := &mockParallelObjectStorage{supports: true}
+ wrapped := newObjectStorageHTTPTrace(upstream)
+
+ err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil)
+ require.NoError(t, err)
+ require.NotNil(t, upstream.ctx)
+ require.NotNil(t, httptrace.ContextClientTrace(upstream.ctx))
+ require.Equal(t, "key", upstream.key)
+}
+
+func TestObjectStorageHTTPTraceWriteMultipartParallelUnsupported(t *testing.T) {
+ wrapped := newObjectStorageHTTPTrace(dummyObjectStorage{})
+
+ err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil)
+ require.Error(t, err)
+ require.True(t, moerr.IsMoErrCode(err, moerr.ErrNotSupported))
+}
+
+func TestObjectStorageHTTPTraceDelegates(t *testing.T) {
+ upstream := &recordingObjectStorage{}
+ wrapped := newObjectStorageHTTPTrace(upstream)
+ ctx := context.Background()
+
+ require.NoError(t, wrapped.Delete(ctx, "a"))
+ exists, err := wrapped.Exists(ctx, "b")
+ require.NoError(t, err)
+ require.True(t, exists)
+ iterSeq := wrapped.List(ctx, "c")
+ var listed []string
+ iterSeq(func(entry *DirEntry, err error) bool {
+ require.NoError(t, err)
+ listed = append(listed, entry.Name)
+ return true
+ })
+ reader, err := wrapped.Read(ctx, "d", nil, nil)
+ require.NoError(t, err)
+ defer reader.Close()
+ buf := make([]byte, 4)
+ _, _ = reader.Read(buf)
+ size, err := wrapped.Stat(ctx, "e")
+ require.NoError(t, err)
+ require.Equal(t, int64(3), size)
+ require.NoError(t, wrapped.Write(ctx, "f", strings.NewReader("payload"), nil, nil))
+
+ require.Len(t, upstream.calls, 6)
+ for _, ctx := range upstream.ctxs {
+ require.NotNil(t, httptrace.ContextClientTrace(ctx))
+ }
+ require.ElementsMatch(t, []string{"delete", "exists", "list", "read", "stat", "write"}, upstream.calls)
+ require.Equal(t, []string{"one"}, listed)
+}
diff --git a/pkg/fileservice/object_storage_metrics.go b/pkg/fileservice/object_storage_metrics.go
index 926b5d885eeaa..b0bce4c67207a 100644
--- a/pkg/fileservice/object_storage_metrics.go
+++ b/pkg/fileservice/object_storage_metrics.go
@@ -20,6 +20,7 @@ import (
"iter"
"time"
+ "github.com/matrixorigin/matrixone/pkg/common/moerr"
metric "github.com/matrixorigin/matrixone/pkg/util/metric/v2"
"github.com/prometheus/client_golang/prometheus"
)
@@ -55,6 +56,7 @@ func newObjectStorageMetrics(
}
var _ ObjectStorage = new(objectStorageMetrics)
+var _ ParallelMultipartWriter = new(objectStorageMetrics)
func (o *objectStorageMetrics) Delete(ctx context.Context, keys ...string) (err error) {
o.numDelete.Inc()
@@ -99,3 +101,18 @@ func (o *objectStorageMetrics) Write(ctx context.Context, key string, r io.Reade
o.numWrite.Inc()
return o.upstream.Write(ctx, key, r, sizeHint, expire)
}
+
+func (o *objectStorageMetrics) SupportsParallelMultipart() bool {
+ if p, ok := o.upstream.(ParallelMultipartWriter); ok {
+ return p.SupportsParallelMultipart()
+ }
+ return false
+}
+
+func (o *objectStorageMetrics) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error {
+ o.numWrite.Inc()
+ if p, ok := o.upstream.(ParallelMultipartWriter); ok {
+ return p.WriteMultipartParallel(ctx, key, r, sizeHint, opt)
+ }
+ return moerr.NewNotSupportedNoCtx("parallel multipart upload")
+}
diff --git a/pkg/fileservice/object_storage_metrics_test.go b/pkg/fileservice/object_storage_metrics_test.go
new file mode 100644
index 0000000000000..5a7c99592c7ac
--- /dev/null
+++ b/pkg/fileservice/object_storage_metrics_test.go
@@ -0,0 +1,93 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ "github.com/matrixorigin/matrixone/pkg/common/moerr"
+ metric "github.com/matrixorigin/matrixone/pkg/util/metric/v2"
+ "github.com/prometheus/client_golang/prometheus/testutil"
+ "github.com/stretchr/testify/require"
+)
+
+func TestObjectStorageMetricsWriteMultipartParallel(t *testing.T) {
+ name := t.Name()
+ upstream := &mockParallelObjectStorage{supports: true}
+ wrapped := newObjectStorageMetrics(upstream, name)
+ gauge := metric.FSObjectStorageOperations.WithLabelValues(name, "write")
+ before := testutil.ToFloat64(gauge)
+
+ err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil)
+ require.NoError(t, err)
+ require.Equal(t, before+1, testutil.ToFloat64(gauge))
+ require.Equal(t, "key", upstream.key)
+}
+
+func TestObjectStorageMetricsWriteMultipartParallelNotSupported(t *testing.T) {
+ name := t.Name()
+ wrapped := newObjectStorageMetrics(dummyObjectStorage{}, name)
+ gauge := metric.FSObjectStorageOperations.WithLabelValues(name, "write")
+ before := testutil.ToFloat64(gauge)
+
+ err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil)
+ require.Error(t, err)
+ require.True(t, moerr.IsMoErrCode(err, moerr.ErrNotSupported))
+ require.Equal(t, before+1, testutil.ToFloat64(gauge))
+}
+
+func TestObjectStorageMetricsDelegates(t *testing.T) {
+ name := t.Name()
+ upstream := &recordingObjectStorage{}
+ wrapped := newObjectStorageMetrics(upstream, name)
+
+ require.NoError(t, wrapped.Delete(context.Background(), "a"))
+ _, _ = wrapped.Exists(context.Background(), "b")
+ seq := wrapped.List(context.Background(), "c")
+ seq(func(_ *DirEntry, _ error) bool { return true })
+ rc, err := wrapped.Read(context.Background(), "d", nil, nil)
+ require.NoError(t, err)
+ require.NoError(t, rc.Close())
+ _, _ = wrapped.Stat(context.Background(), "e")
+ require.NoError(t, wrapped.Write(context.Background(), "f", strings.NewReader("x"), nil, nil))
+
+ require.ElementsMatch(t, []string{
+ "delete", "exists", "list", "read", "stat", "write",
+ }, upstream.calls)
+
+ // gauges incremented
+ require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "delete")) >= 1)
+ require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "exists")) >= 1)
+ require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "list")) >= 1)
+ require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "read")) >= 1)
+ require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "stat")) >= 1)
+ require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "write")) >= 1)
+}
+
+func TestObjectStorageMetricsReadCloseDecrementsActive(t *testing.T) {
+ name := t.Name()
+ upstream := &recordingObjectStorage{}
+ wrapped := newObjectStorageMetrics(upstream, name)
+ active := metric.FSObjectStorageOperations.WithLabelValues(name, "active-read")
+ before := testutil.ToFloat64(active)
+
+ r, err := wrapped.Read(context.Background(), "key", nil, nil)
+ require.NoError(t, err)
+ require.Equal(t, before+1, testutil.ToFloat64(active))
+ require.NoError(t, r.Close())
+ require.Equal(t, before, testutil.ToFloat64(active))
+}
diff --git a/pkg/fileservice/object_storage_semaphore.go b/pkg/fileservice/object_storage_semaphore.go
index 243e1eaf7c805..a6d4768af15d0 100644
--- a/pkg/fileservice/object_storage_semaphore.go
+++ b/pkg/fileservice/object_storage_semaphore.go
@@ -16,6 +16,7 @@ package fileservice
import (
"context"
+ "github.com/matrixorigin/matrixone/pkg/common/moerr"
"io"
"iter"
"sync"
@@ -46,6 +47,7 @@ func (o *objectStorageSemaphore) release() {
}
var _ ObjectStorage = new(objectStorageSemaphore)
+var _ ParallelMultipartWriter = new(objectStorageSemaphore)
func (o *objectStorageSemaphore) Delete(ctx context.Context, keys ...string) (err error) {
o.acquire()
@@ -107,3 +109,19 @@ func (o *objectStorageSemaphore) Write(ctx context.Context, key string, r io.Rea
defer o.release()
return o.upstream.Write(ctx, key, r, sizeHint, expire)
}
+
+func (o *objectStorageSemaphore) SupportsParallelMultipart() bool {
+ if p, ok := o.upstream.(ParallelMultipartWriter); ok {
+ return p.SupportsParallelMultipart()
+ }
+ return false
+}
+
+func (o *objectStorageSemaphore) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) (err error) {
+ o.acquire()
+ defer o.release()
+ if p, ok := o.upstream.(ParallelMultipartWriter); ok {
+ return p.WriteMultipartParallel(ctx, key, r, sizeHint, opt)
+ }
+ return moerr.NewNotSupportedNoCtx("parallel multipart upload")
+}
diff --git a/pkg/fileservice/object_storage_semaphore_test.go b/pkg/fileservice/object_storage_semaphore_test.go
new file mode 100644
index 0000000000000..c9adcae8919d9
--- /dev/null
+++ b/pkg/fileservice/object_storage_semaphore_test.go
@@ -0,0 +1,88 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestObjectStorageSemaphoreSerializes(t *testing.T) {
+ start := make(chan struct{}, 2)
+ wait := make(chan struct{})
+ upstream := &blockingObjectStorage{
+ start: start,
+ wait: wait,
+ }
+ sem := newObjectStorageSemaphore(upstream, 1)
+
+ done := make(chan struct{})
+ go func() {
+ require.NoError(t, sem.Write(context.Background(), "a", nil, nil, nil))
+ close(done)
+ }()
+
+ select {
+ case <-start:
+ case <-time.After(time.Second):
+ t.Fatal("first write did not start")
+ }
+
+ startSecond := make(chan struct{})
+ go func() {
+ defer close(startSecond)
+ require.NoError(t, sem.Write(context.Background(), "b", nil, nil, nil))
+ }()
+
+ select {
+ case <-startSecond:
+ t.Fatal("second write started before release")
+ case <-time.After(50 * time.Millisecond):
+ }
+
+ close(wait) // release first
+ select {
+ case <-startSecond:
+ case <-time.After(time.Second):
+ t.Fatal("second write not started after release")
+ }
+ <-done
+}
+
+func TestObjectStorageSemaphoreReleasesOnError(t *testing.T) {
+ start := make(chan struct{}, 1)
+ wait := make(chan struct{})
+ upstream := &blockingObjectStorage{
+ start: start,
+ wait: wait,
+ err: context.DeadlineExceeded,
+ }
+ sem := newObjectStorageSemaphore(upstream, 1)
+
+ // release the blocked write once it has started
+ go func() {
+ <-start
+ close(wait)
+ }()
+
+ err := sem.Write(context.Background(), "a", nil, nil, nil)
+ require.Error(t, err)
+
+ // another call should proceed after the failed one
+ require.NoError(t, sem.Delete(context.Background(), "x"))
+}
diff --git a/pkg/fileservice/object_storage_test_helper_test.go b/pkg/fileservice/object_storage_test_helper_test.go
new file mode 100644
index 0000000000000..3ddfb8687cda3
--- /dev/null
+++ b/pkg/fileservice/object_storage_test_helper_test.go
@@ -0,0 +1,130 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "context"
+ "io"
+ "iter"
+ "strings"
+ "time"
+)
+
+type dummyObjectStorage struct{}
+
+func (dummyObjectStorage) Delete(ctx context.Context, keys ...string) error {
+ return nil
+}
+
+func (dummyObjectStorage) Exists(ctx context.Context, key string) (bool, error) {
+ return false, nil
+}
+
+func (dummyObjectStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] {
+ return func(yield func(*DirEntry, error) bool) {}
+}
+
+func (dummyObjectStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) {
+ return io.NopCloser(strings.NewReader("")), nil
+}
+
+func (dummyObjectStorage) Stat(ctx context.Context, key string) (int64, error) {
+ return 0, nil
+}
+
+func (dummyObjectStorage) Write(ctx context.Context, key string, r io.Reader, sizeHint *int64, expire *time.Time) error {
+ _, _ = io.Copy(io.Discard, r)
+ return nil
+}
+
+type mockParallelObjectStorage struct {
+ dummyObjectStorage
+ ctx context.Context
+ key string
+ sizeHint *int64
+ opt *ParallelMultipartOption
+ err error
+ supports bool
+}
+
+func (m *mockParallelObjectStorage) SupportsParallelMultipart() bool {
+ return m.supports
+}
+
+func (m *mockParallelObjectStorage) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error {
+ m.ctx = ctx
+ m.key = key
+ m.sizeHint = sizeHint
+ m.opt = opt
+ return m.err
+}
+
+type recordingObjectStorage struct {
+ calls []string
+ ctxs []context.Context
+}
+
+func (r *recordingObjectStorage) record(ctx context.Context, name string) {
+ r.calls = append(r.calls, name)
+ r.ctxs = append(r.ctxs, ctx)
+}
+
+func (r *recordingObjectStorage) Delete(ctx context.Context, keys ...string) error {
+ r.record(ctx, "delete")
+ return nil
+}
+
+func (r *recordingObjectStorage) Exists(ctx context.Context, key string) (bool, error) {
+ r.record(ctx, "exists")
+ return true, nil
+}
+
+func (r *recordingObjectStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] {
+ r.record(ctx, "list")
+ return func(yield func(*DirEntry, error) bool) {
+ yield(&DirEntry{
+ Name: "one",
+ }, nil)
+ }
+}
+
+func (r *recordingObjectStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) {
+ r.record(ctx, "read")
+ return io.NopCloser(strings.NewReader("data")), nil
+}
+
+func (r *recordingObjectStorage) Stat(ctx context.Context, key string) (int64, error) {
+ r.record(ctx, "stat")
+ return 3, nil
+}
+
+func (r *recordingObjectStorage) Write(ctx context.Context, key string, rd io.Reader, sizeHint *int64, expire *time.Time) error {
+ r.record(ctx, "write")
+ _, _ = io.Copy(io.Discard, rd)
+ return nil
+}
+
+type blockingObjectStorage struct {
+ dummyObjectStorage
+ start chan struct{}
+ wait chan struct{}
+ err error
+}
+
+func (b *blockingObjectStorage) Write(ctx context.Context, key string, rd io.Reader, sizeHint *int64, expire *time.Time) error {
+ b.start <- struct{}{}
+ <-b.wait
+ return b.err
+}
diff --git a/pkg/fileservice/parallel_mode.go b/pkg/fileservice/parallel_mode.go
new file mode 100644
index 0000000000000..6201ea0811c5d
--- /dev/null
+++ b/pkg/fileservice/parallel_mode.go
@@ -0,0 +1,62 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "context"
+ "strings"
+)
+
+// ParallelMode controls when multipart parallel uploads are used.
+type ParallelMode uint8
+
+const (
+ // ParallelOff disables parallel multipart uploads.
+ ParallelOff ParallelMode = iota
+ // ParallelAuto leaves the decision to the file service implementation.
+ ParallelAuto
+ // ParallelForce always requests the parallel multipart path when the
+ // backend supports it.
+ ParallelForce
+)
+
+func parseParallelMode(s string) (ParallelMode, bool) {
+ switch strings.ToLower(s) {
+ case "off", "false", "0", "":
+ return ParallelOff, true
+ case "auto":
+ return ParallelAuto, true
+ case "force", "on", "true", "1":
+ return ParallelForce, true
+ default:
+ return ParallelOff, false
+ }
+}
+
+type parallelModeKey struct{}
+
+// WithParallelMode sets a per-call parallel mode override on context.
+func WithParallelMode(ctx context.Context, mode ParallelMode) context.Context {
+ return context.WithValue(ctx, parallelModeKey{}, mode)
+}
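+
+// Example (illustrative): force parallel multipart uploads for a single call,
+// regardless of the MO_ETL_PARALLEL_MODE environment variable:
+//
+//	ctx := WithParallelMode(context.Background(), ParallelForce)
+//	// pass ctx to GetForETL (or any other write path that consults the override)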
+
+// parallelModeFromContext retrieves a parallel mode override if present.
+func parallelModeFromContext(ctx context.Context) (ParallelMode, bool) {
+ if ctx == nil {
+ return ParallelOff, false
+ }
+ if v := ctx.Value(parallelModeKey{}); v != nil {
+ if mode, ok := v.(ParallelMode); ok {
+ return mode, true
+ }
+ }
+ return ParallelOff, false
+}
diff --git a/pkg/fileservice/parallel_mode_test.go b/pkg/fileservice/parallel_mode_test.go
new file mode 100644
index 0000000000000..ae9d8b345c9f6
--- /dev/null
+++ b/pkg/fileservice/parallel_mode_test.go
@@ -0,0 +1,74 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "context"
+ "testing"
+)
+
+func TestParseParallelModeVariants(t *testing.T) {
+ cases := []struct {
+ in string
+ expect ParallelMode
+ ok bool
+ }{
+ {"", ParallelOff, true},
+ {"off", ParallelOff, true},
+ {"false", ParallelOff, true},
+ {"0", ParallelOff, true},
+ {"auto", ParallelAuto, true},
+ {"force", ParallelForce, true},
+ {"on", ParallelForce, true},
+ {"1", ParallelForce, true},
+ {"xxx", ParallelOff, false},
+ }
+
+ for _, c := range cases {
+ mode, ok := parseParallelMode(c.in)
+ if mode != c.expect || ok != c.ok {
+ t.Fatalf("input %s got (%v,%v) expect (%v,%v)", c.in, mode, ok, c.expect, c.ok)
+ }
+ }
+}
+
+func TestWithParallelModeOnContext(t *testing.T) {
+ base := context.Background()
+ ctx := WithParallelMode(base, ParallelForce)
+ mode, ok := parallelModeFromContext(ctx)
+ if !ok || mode != ParallelForce {
+ t.Fatalf("expected force override, got %v %v", mode, ok)
+ }
+ if _, ok := parallelModeFromContext(base); ok {
+ t.Fatalf("unexpected mode on base context")
+ }
+}
+
+func TestETLParallelModePriority(t *testing.T) {
+ t.Setenv("MO_ETL_PARALLEL_MODE", "off")
+ if m := etlParallelMode(context.Background()); m != ParallelOff {
+ t.Fatalf("expect off from env, got %v", m)
+ }
+
+ ctx := WithParallelMode(context.Background(), ParallelForce)
+ if m := etlParallelMode(ctx); m != ParallelForce {
+ t.Fatalf("ctx override should win, got %v", m)
+ }
+
+ t.Setenv("MO_ETL_PARALLEL_MODE", "auto")
+ if m := etlParallelMode(context.Background()); m != ParallelAuto {
+ t.Fatalf("env auto not applied, got %v", m)
+ }
+}
diff --git a/pkg/fileservice/parallel_sdk_test.go b/pkg/fileservice/parallel_sdk_test.go
new file mode 100644
index 0000000000000..54edf6511c451
--- /dev/null
+++ b/pkg/fileservice/parallel_sdk_test.go
@@ -0,0 +1,552 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/credentials"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ costypes "github.com/tencentyun/cos-go-sdk-v5"
+)
+
+func newMockAWSServer(t *testing.T, failPart int32) (*httptest.Server, *awsServerState) {
+ t.Helper()
+ state := &awsServerState{
+ failPart: failPart,
+ parts: make(map[int32][]byte),
+ }
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploads"):
+ if state.failCreate {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ _, _ = io.ReadAll(r.Body)
+ w.Header().Set("Content-Type", "application/xml")
+ _, _ = fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, state.uploadID)
+ case r.Method == http.MethodPut && !strings.Contains(r.URL.RawQuery, "partNumber") && !strings.Contains(r.URL.RawQuery, "uploadId") && !strings.Contains(r.URL.RawQuery, "uploads"):
+ body, _ := io.ReadAll(r.Body)
+ state.mu.Lock()
+ state.putCount++
+ state.putBody = append([]byte{}, body...)
+ state.mu.Unlock()
+ w.WriteHeader(http.StatusOK)
+ case r.Method == http.MethodPut && strings.Contains(r.URL.RawQuery, "partNumber"):
+ partStr := r.URL.Query().Get("partNumber")
+ pn, _ := strconv.Atoi(partStr)
+ body, _ := io.ReadAll(r.Body)
+ if state.failPart > 0 && int32(pn) == state.failPart {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ state.mu.Lock()
+ state.parts[int32(pn)] = body
+ state.mu.Unlock()
+ w.Header().Set("ETag", fmt.Sprintf("\"etag-%d\"", pn))
+ case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploadId"):
+ body, _ := io.ReadAll(r.Body)
+ if state.failComplete {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ state.mu.Lock()
+ state.completeBody = append([]byte{}, body...)
+ state.mu.Unlock()
+ w.Header().Set("Content-Type", "application/xml")
+ _, _ = w.Write([]byte(`<CompleteMultipartUploadResult><Location>loc</Location><Bucket>bucket</Bucket><Key>object</Key><ETag>"etag"</ETag></CompleteMultipartUploadResult>`))
+ case r.Method == http.MethodDelete && strings.Contains(r.URL.RawQuery, "uploadId"):
+ state.aborted.Store(true)
+ w.WriteHeader(http.StatusOK)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ })
+ return httptest.NewServer(handler), state
+}
+
+type awsServerState struct {
+ mu sync.Mutex
+ parts map[int32][]byte
+ completeBody []byte
+ failPart int32
+ failComplete bool
+ failCreate bool
+ uploadID string
+ aborted atomic.Bool
+ putCount int
+ putBody []byte
+}
+
+func newTestAWSClient(t *testing.T, srv *httptest.Server) *AwsSDKv2 {
+ t.Helper()
+ cfg := aws.Config{
+ Region: "us-east-1",
+ Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider("id", "key", "")),
+ HTTPClient: srv.Client(),
+ Retryer: func() aws.Retryer {
+ return aws.NopRetryer{}
+ },
+ }
+ client := s3.NewFromConfig(cfg, func(o *s3.Options) {
+ o.UsePathStyle = true
+ o.BaseEndpoint = aws.String(srv.URL)
+ })
+
+ return &AwsSDKv2{
+ name: "aws-test",
+ bucket: "bucket",
+ client: client,
+ }
+}
+
+func TestAwsParallelMultipartSuccess(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "uid-success"
+
+ sdk := newTestAWSClient(t, server)
+
+ data := bytes.Repeat([]byte("a"), int(minMultipartPartSize+1))
+ size := int64(len(data))
+ err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ })
+ if err != nil {
+ t.Fatalf("write failed: %v, parts=%d, complete=%s", err, len(state.parts), string(state.completeBody))
+ }
+ if len(state.parts) != 2 {
+ t.Fatalf("expected 2 parts, got %d", len(state.parts))
+ }
+ if len(state.completeBody) == 0 {
+ t.Fatalf("complete body not recorded")
+ }
+}
+
+func TestAwsParallelMultipartAbortOnError(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "uid-fail"
+ state.failComplete = true
+
+ sdk := newTestAWSClient(t, server)
+
+ data := bytes.Repeat([]byte("b"), int(minMultipartPartSize*2))
+ size := int64(len(data))
+ err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ })
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+ if !state.aborted.Load() {
+ t.Fatalf("expected abort request to be sent")
+ }
+}
+
+func TestAwsMultipartFallbackSmallSize(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "uid-small"
+
+ sdk := newTestAWSClient(t, server)
+ data := bytes.Repeat([]byte("x"), int(minMultipartPartSize-1))
+ size := int64(len(data))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if state.putCount != 1 {
+ t.Fatalf("expected single PUT fallback, got %d", state.putCount)
+ }
+ if string(state.putBody) != string(data) {
+ t.Fatalf("unexpected put body")
+ }
+}
+
+func TestAwsMultipartFallbackUnknownSmall(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "uid-unknown"
+
+ sdk := newTestAWSClient(t, server)
+ data := bytes.Repeat([]byte("y"), int(minMultipartPartSize-1))
+ reader := bytes.NewReader(data)
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", reader, nil, nil); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if state.putCount != 1 {
+ t.Fatalf("expected fallback PUT for unknown small")
+ }
+ if string(state.putBody) != string(data) {
+ t.Fatalf("unexpected put body")
+ }
+}
+
+func TestAwsMultipartTooManyParts(t *testing.T) {
+ server, _ := newMockAWSServer(t, 0)
+ defer server.Close()
+
+ sdk := newTestAWSClient(t, server)
+ huge := int64(maxMultipartParts+1) * defaultParallelMultipartPartSize
+ err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), &huge, &ParallelMultipartOption{
+ PartSize: defaultParallelMultipartPartSize,
+ })
+ if err == nil {
+ t.Fatalf("expected error for too many parts")
+ }
+}
+
+func TestAwsMultipartUploadPartError(t *testing.T) {
+ server, _ := newMockAWSServer(t, 1)
+ defer server.Close()
+
+ sdk := newTestAWSClient(t, server)
+ data := bytes.Repeat([]byte("p"), int(minMultipartPartSize*2))
+ size := int64(len(data))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ }); err == nil {
+ t.Fatalf("expected upload part error")
+ }
+}
+
+func TestAwsParallelMultipartUnknownSize(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "uid-unknown-size"
+
+ sdk := newTestAWSClient(t, server)
+ data := bytes.Repeat([]byte("z"), int(minMultipartPartSize+1))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), nil, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ }); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if len(state.parts) != 2 {
+ t.Fatalf("expected multipart upload with unknown size")
+ }
+}
+
+func TestAwsMultipartEmptyReader(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+
+ sdk := newTestAWSClient(t, server)
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), nil, nil); err != nil {
+ t.Fatalf("expected nil error for empty reader, got %v", err)
+ }
+ if state.putCount != 0 && len(state.parts) != 0 {
+ t.Fatalf("no upload should happen for empty reader")
+ }
+}
+
+func TestAwsMultipartContextCanceled(t *testing.T) {
+ server, _ := newMockAWSServer(t, 0)
+ defer server.Close()
+
+ sdk := newTestAWSClient(t, server)
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ if err := sdk.WriteMultipartParallel(ctx, "object", bytes.NewReader([]byte("data")), nil, nil); err == nil {
+ t.Fatalf("expected context canceled error")
+ }
+}
+
+func TestAwsMultipartCreateFail(t *testing.T) {
+ server, state := newMockAWSServer(t, 0)
+ defer server.Close()
+ state.failCreate = true
+
+ sdk := newTestAWSClient(t, server)
+ data := bytes.Repeat([]byte("i"), int(minMultipartPartSize+1))
+ size := int64(len(data))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err == nil {
+ t.Fatalf("expected create multipart error")
+ }
+}
+
+func newMockCOSServer(t *testing.T, failPart int) (*httptest.Server, *cosServerState) {
+ t.Helper()
+ state := &cosServerState{
+ failPart: failPart,
+ parts: make(map[int][]byte),
+ }
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploads"):
+ if state.failCreate {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ _, _ = io.ReadAll(r.Body)
+ w.Header().Set("Content-Type", "application/xml")
+ _, _ = fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, state.uploadID)
+ case r.Method == http.MethodPut && !strings.Contains(r.URL.RawQuery, "partNumber") && !strings.Contains(r.URL.RawQuery, "uploadId") && !strings.Contains(r.URL.RawQuery, "uploads"):
+ body, _ := io.ReadAll(r.Body)
+ state.mu.Lock()
+ state.putCount++
+ state.putBody = append([]byte{}, body...)
+ state.mu.Unlock()
+ w.WriteHeader(http.StatusOK)
+ case r.Method == http.MethodPut && strings.Contains(r.URL.RawQuery, "partNumber"):
+ partStr := r.URL.Query().Get("partNumber")
+ pn, _ := strconv.Atoi(partStr)
+ body, _ := io.ReadAll(r.Body)
+ if state.failPart > 0 && pn == state.failPart {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ state.mu.Lock()
+ state.parts[pn] = body
+ state.mu.Unlock()
+ case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploadId"):
+ body, _ := io.ReadAll(r.Body)
+ if state.failComplete {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ state.mu.Lock()
+ state.completeBody = append([]byte{}, body...)
+ state.mu.Unlock()
+ w.Header().Set("Content-Type", "application/xml")
+ state.respBody = `<CompleteMultipartUploadResult><Location>loc</Location><Bucket>bucket</Bucket><Key>object</Key><ETag>etag</ETag></CompleteMultipartUploadResult>`
+ state.completed.Store(true)
+ _, _ = w.Write([]byte(state.respBody))
+ case r.Method == http.MethodDelete && strings.Contains(r.URL.RawQuery, "uploadId"):
+ state.aborted.Store(true)
+ w.WriteHeader(http.StatusOK)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ })
+ return httptest.NewServer(handler), state
+}
+
+type cosServerState struct {
+ mu sync.Mutex
+ parts map[int][]byte
+ completeBody []byte
+ failPart int
+ failComplete bool
+ failCreate bool
+ uploadID string
+ aborted atomic.Bool
+ completed atomic.Bool
+ respBody string
+ putCount int
+ putBody []byte
+}
+
+func newTestCOSClient(t *testing.T, srv *httptest.Server) *QCloudSDK {
+ t.Helper()
+ baseURL, err := url.Parse(srv.URL)
+ if err != nil {
+ t.Fatalf("parse url: %v", err)
+ }
+
+ client := costypes.NewClient(
+ &costypes.BaseURL{BucketURL: baseURL},
+ srv.Client(),
+ )
+ client.Conf.EnableCRC = false
+
+ return &QCloudSDK{
+ name: "cos-test",
+ client: client,
+ }
+}
+
+func TestCOSParallelMultipartSuccess(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "cos-uid"
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("c"), int(minMultipartPartSize+2))
+ size := int64(len(data))
+ err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ })
+ if err != nil {
+ t.Fatalf("write failed: %v, parts=%d, complete=%s", err, len(state.parts), string(state.completeBody))
+ }
+ if len(state.parts) != 2 {
+ t.Fatalf("expected 2 parts, got %d", len(state.parts))
+ }
+ if len(state.completeBody) == 0 {
+ t.Fatalf("complete body not recorded")
+ }
+}
+
+func TestCOSParallelMultipartAbortOnError(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "cos-uid-fail"
+ state.failComplete = true
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("d"), int(minMultipartPartSize*2))
+ size := int64(len(data))
+ err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ })
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+ if !state.aborted.Load() {
+ t.Fatalf("expected abort request")
+ }
+}
+
+func TestCOSMultipartFallbackSmallSize(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "cos-small"
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("e"), int(minMultipartPartSize-1))
+ size := int64(len(data))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if state.putCount != 1 {
+ t.Fatalf("expected single PUT fallback, got %d", state.putCount)
+ }
+ if string(state.putBody) != string(data) {
+ t.Fatalf("unexpected put body")
+ }
+}
+
+func TestCOSMultipartFallbackUnknownSmall(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "cos-unknown"
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("f"), int(minMultipartPartSize-1))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), nil, nil); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if state.putCount != 1 {
+ t.Fatalf("expected fallback PUT for unknown small")
+ }
+ if string(state.putBody) != string(data) {
+ t.Fatalf("unexpected put body")
+ }
+}
+
+func TestCOSMultipartTooManyParts(t *testing.T) {
+ server, _ := newMockCOSServer(t, 0)
+ defer server.Close()
+
+ sdk := newTestCOSClient(t, server)
+ huge := int64(maxMultipartParts+1) * defaultParallelMultipartPartSize
+ err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), &huge, &ParallelMultipartOption{
+ PartSize: defaultParallelMultipartPartSize,
+ })
+ if err == nil {
+ t.Fatalf("expected error for too many parts")
+ }
+}
+
+func TestCOSMultipartUploadPartError(t *testing.T) {
+ server, _ := newMockCOSServer(t, 1)
+ defer server.Close()
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("h"), int(minMultipartPartSize*2))
+ size := int64(len(data))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ }); err == nil {
+ t.Fatalf("expected upload part error")
+ }
+}
+
+func TestCOSParallelMultipartUnknownSize(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+ state.uploadID = "cos-unknown-size"
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("g"), int(minMultipartPartSize+1))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), nil, &ParallelMultipartOption{
+ PartSize: minMultipartPartSize,
+ Concurrency: 2,
+ }); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if len(state.parts) != 2 {
+ t.Fatalf("expected multipart upload with unknown size")
+ }
+}
+
+func TestCOSMultipartEmptyReader(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+
+ sdk := newTestCOSClient(t, server)
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), nil, nil); err != nil {
+ t.Fatalf("expected nil error for empty reader, got %v", err)
+ }
+ if state.putCount != 0 && len(state.parts) != 0 {
+ t.Fatalf("no upload should happen for empty reader")
+ }
+}
+
+func TestCOSMultipartContextCanceled(t *testing.T) {
+ server, _ := newMockCOSServer(t, 0)
+ defer server.Close()
+
+ sdk := newTestCOSClient(t, server)
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ if err := sdk.WriteMultipartParallel(ctx, "object", bytes.NewReader([]byte("data")), nil, nil); err == nil {
+ t.Fatalf("expected context canceled error")
+ }
+}
+
+func TestCOSMultipartCreateFail(t *testing.T) {
+ server, state := newMockCOSServer(t, 0)
+ defer server.Close()
+ state.failCreate = true
+
+ sdk := newTestCOSClient(t, server)
+ data := bytes.Repeat([]byte("j"), int(minMultipartPartSize+1))
+ size := int64(len(data))
+ if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err == nil {
+ t.Fatalf("expected create multipart error")
+ }
+}
diff --git a/pkg/fileservice/parallel_upload_test.go b/pkg/fileservice/parallel_upload_test.go
new file mode 100644
index 0000000000000..e3e3fb435931b
--- /dev/null
+++ b/pkg/fileservice/parallel_upload_test.go
@@ -0,0 +1,511 @@
+// Copyright 2025 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fileservice
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "iter"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestNormalizeParallelOption(t *testing.T) {
+ cases := []struct {
+ name string
+ opt *ParallelMultipartOption
+ expectSize int64
+ expectConc int
+ expectExpire bool
+ }{
+ {
+ name: "default values",
+ opt: nil,
+ expectSize: defaultParallelMultipartPartSize,
+ expectConc: runtime.NumCPU(),
+ },
+ {
+ name: "clamp below min",
+ opt: &ParallelMultipartOption{
+ PartSize: minMultipartPartSize - 1,
+ },
+ expectSize: minMultipartPartSize,
+ expectConc: runtime.NumCPU(),
+ },
+ {
+ name: "clamp above max",
+ opt: &ParallelMultipartOption{
+ PartSize: maxMultipartPartSize + 1,
+ },
+ expectSize: maxMultipartPartSize,
+ expectConc: runtime.NumCPU(),
+ },
+ {
+ name: "custom values",
+ opt: &ParallelMultipartOption{
+ PartSize: 16 << 20,
+ Concurrency: 3,
+ },
+ expectSize: 16 << 20,
+ expectConc: 3,
+ },
+ }
+
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ res := normalizeParallelOption(c.opt)
+ if res.PartSize != c.expectSize {
+ t.Fatalf("part size mismatch, got %d expect %d", res.PartSize, c.expectSize)
+ }
+ if res.Concurrency != c.expectConc {
+ t.Fatalf("concurrency mismatch, got %d expect %d", res.Concurrency, c.expectConc)
+ }
+ })
+ }
+}
+
+func TestGetParallelUploadPoolSingleton(t *testing.T) {
+ p1 := getParallelUploadPool()
+ p2 := getParallelUploadPool()
+ if p1 == nil || p2 == nil {
+ t.Fatalf("pool not initialized")
+ }
+ if p1 != p2 {
+ t.Fatalf("expected singleton pool")
+ }
+}
+
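+// mockParallelStorage implements ObjectStorage plus ParallelMultipartWriter and records the
+// last key, payload, size hint, and options so tests can assert whether the plain Write or
+// the parallel multipart path was taken.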
+type mockParallelStorage struct {
+ supports bool
+ exists bool
+
+ writeCalled int
+ mpCalled int
+
+ lastKey string
+ lastData []byte
+ lastSize *int64
+ lastOpt ParallelMultipartOption
+ lastExpire *time.Time
+}
+
+func (m *mockParallelStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] {
+ return func(yield func(*DirEntry, error) bool) {}
+}
+
+func (m *mockParallelStorage) Stat(ctx context.Context, key string) (int64, error) {
+ return 0, nil
+}
+
+func (m *mockParallelStorage) Exists(ctx context.Context, key string) (bool, error) {
+ return m.exists, nil
+}
+
+func (m *mockParallelStorage) Write(ctx context.Context, key string, r io.Reader, sizeHint *int64, expire *time.Time) error {
+ m.writeCalled++
+ m.lastKey = key
+ m.lastSize = sizeHint
+ m.lastExpire = expire
+ b, _ := io.ReadAll(r)
+ m.lastData = b
+ return nil
+}
+
+func (m *mockParallelStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) {
+ return io.NopCloser(bytes.NewReader(nil)), nil
+}
+
+func (m *mockParallelStorage) Delete(ctx context.Context, keys ...string) error {
+ return nil
+}
+
+func (m *mockParallelStorage) SupportsParallelMultipart() bool {
+ return m.supports
+}
+
+func (m *mockParallelStorage) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error {
+ m.mpCalled++
+ m.lastKey = key
+ if opt != nil {
+ m.lastOpt = *opt
+ }
+ m.lastSize = sizeHint
+ b, _ := io.ReadAll(r)
+ m.lastData = b
+ return nil
+}
+
+func TestS3FSWriteUsesParallelWhenForced(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelForce,
+ }
+
+ data := []byte("hello")
+ vector := IOVector{
+ FilePath: "obj",
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 1 {
+ t.Fatalf("expected parallel write, got %d", storage.mpCalled)
+ }
+ if storage.writeCalled != 0 {
+ t.Fatalf("unexpected non-parallel write")
+ }
+ if storage.lastKey != "obj" {
+ t.Fatalf("unexpected key %s", storage.lastKey)
+ }
+ if string(storage.lastData) != "hello" {
+ t.Fatalf("unexpected data %s", string(storage.lastData))
+ }
+}
+
+func TestS3FSWriteDisableParallel(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelOff,
+ }
+
+ data := bytes.Repeat([]byte("a"), int(minMultipartPartSize+1))
+ vector := IOVector{
+ FilePath: "large",
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 0 {
+ t.Fatalf("parallel write should be disabled")
+ }
+ if storage.writeCalled != 1 {
+ t.Fatalf("expected single write, got %d", storage.writeCalled)
+ }
+ if storage.lastKey != "large" {
+ t.Fatalf("unexpected key %s", storage.lastKey)
+ }
+}
+
+func TestS3FSWriteUnknownSizeUsesParallelWhenForced(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelForce,
+ }
+
+ reader := strings.NewReader("data")
+ vector := IOVector{
+ FilePath: "unknown",
+ Entries: []IOEntry{
+ {Offset: 0, Size: -1, ReaderForWrite: reader},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 1 {
+ t.Fatalf("expected parallel write for unknown size")
+ }
+ if storage.lastSize != nil {
+ t.Fatalf("size hint should be nil for unknown size")
+ }
+}
+
+func TestS3FSWriteUnknownSizeDefaultNoParallel(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelOff,
+ }
+
+ reader := strings.NewReader("data")
+ vector := IOVector{
+ FilePath: "unknown",
+ Entries: []IOEntry{
+ {Offset: 0, Size: -1, ReaderForWrite: reader},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 0 {
+ t.Fatalf("parallel write should not be used by default")
+ }
+ if storage.writeCalled != 1 {
+ t.Fatalf("expected single write by default")
+ }
+}
+
+func TestS3FSSmallSizeSkipsParallel(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelAuto,
+ }
+
+ data := bytes.Repeat([]byte("b"), int(minMultipartPartSize-1))
+ vector := IOVector{
+ FilePath: "small",
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.writeCalled != 1 {
+ t.Fatalf("expected non-parallel write for small object")
+ }
+ if storage.mpCalled != 0 {
+ t.Fatalf("unexpected parallel write")
+ }
+}
+
+func TestS3FSLargeSizeDefaultNoParallel(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelOff,
+ }
+
+ data := bytes.Repeat([]byte("d"), int(minMultipartPartSize+1))
+ vector := IOVector{
+ FilePath: "large-default",
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 0 {
+ t.Fatalf("parallel write should be disabled by default")
+ }
+ if storage.writeCalled != 1 {
+ t.Fatalf("expected single write by default, got %d", storage.writeCalled)
+ }
+}
+
+func TestS3FSAutoUsesParallelForLarge(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelAuto,
+ }
+
+ data := bytes.Repeat([]byte("p"), int(minMultipartPartSize+1))
+ vector := IOVector{
+ FilePath: "auto-large",
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 1 {
+ t.Fatalf("expected parallel write with auto mode")
+ }
+}
+
+func TestS3FSForceParallelFallbackWhenUnsupported(t *testing.T) {
+ storage := &mockParallelStorage{supports: false}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelForce,
+ }
+
+ data := []byte("force")
+ vector := IOVector{
+ FilePath: "force",
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 0 {
+ t.Fatalf("parallel write should not happen when unsupported")
+ }
+ if storage.writeCalled != 1 {
+ t.Fatalf("expected fallback single write")
+ }
+}
+
+func TestS3FSParallelPassesExpire(t *testing.T) {
+ storage := &mockParallelStorage{supports: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelForce,
+ }
+
+ data := bytes.Repeat([]byte("c"), int(minMultipartPartSize+1))
+ expire := time.Now().Add(time.Hour)
+ vector := IOVector{
+ FilePath: "with-expire",
+ ExpireAt: expire,
+ Entries: []IOEntry{
+ {Offset: 0, Size: int64(len(data)), Data: data},
+ },
+ }
+
+ if err := fs.Write(context.Background(), vector); err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+ if storage.mpCalled != 1 {
+ t.Fatalf("expected parallel write")
+ }
+ if storage.lastOpt.Expire == nil || !storage.lastOpt.Expire.Equal(expire) {
+ t.Fatalf("expire not propagated")
+ }
+}
+
+func TestS3FSWriteFileExists(t *testing.T) {
+ storage := &mockParallelStorage{supports: true, exists: true}
+ fs := &S3FS{
+ name: "s3",
+ storage: storage,
+ ioMerger: NewIOMerger(),
+ asyncUpdate: true,
+ parallelMode: ParallelOff,
+ }
+ vector := IOVector{
+ FilePath: "dup",
+ Entries: []IOEntry{
+ {Offset: 0, Size: 1, Data: []byte("x")},
+ },
+ }
+ err := fs.Write(context.Background(), vector)
+ if err == nil {
+ t.Fatalf("expected file exists error")
+ }
+}
+
+type countingParallelStorage struct {
+ supports bool
+ current atomic.Int64
+ max atomic.Int64
+}
+
+func (c *countingParallelStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] {
+ return func(yield func(*DirEntry, error) bool) {}
+}
+
+func (c *countingParallelStorage) Stat(ctx context.Context, key string) (int64, error) {
+ return 0, nil
+}
+
+func (c *countingParallelStorage) Exists(ctx context.Context, key string) (bool, error) {
+ return false, nil
+}
+
+func (c *countingParallelStorage) Write(ctx context.Context, key string, r io.Reader, sizeHint *int64, expire *time.Time) error {
+ return nil
+}
+
+func (c *countingParallelStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) {
+ return io.NopCloser(strings.NewReader("")), nil
+}
+
+func (c *countingParallelStorage) Delete(ctx context.Context, keys ...string) error {
+ return nil
+}
+
+func (c *countingParallelStorage) SupportsParallelMultipart() bool {
+ return c.supports
+}
+
+func (c *countingParallelStorage) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error {
+ cur := c.current.Add(1)
+ defer c.current.Add(-1)
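+ // record the high-water mark of concurrent calls with a compare-and-swap loop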
+ for {
+ old := c.max.Load()
+ if cur <= old || c.max.CompareAndSwap(old, cur) {
+ break
+ }
+ }
+ time.Sleep(20 * time.Millisecond)
+ return nil
+}
+
+func TestObjectStorageSemaphoreLimitsParallel(t *testing.T) {
+ upstream := &countingParallelStorage{supports: true}
+ sem := newObjectStorageSemaphore(upstream, 1)
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ _ = sem.WriteMultipartParallel(context.Background(), "a", strings.NewReader("1"), nil, nil)
+ }()
+ go func() {
+ defer wg.Done()
+ _ = sem.WriteMultipartParallel(context.Background(), "b", strings.NewReader("2"), nil, nil)
+ }()
+ wg.Wait()
+
+ if upstream.max.Load() > 1 {
+ t.Fatalf("expected semaphore to cap concurrency at 1, got %d", upstream.max.Load())
+ }
+}
diff --git a/pkg/fileservice/qcloud_sdk.go b/pkg/fileservice/qcloud_sdk.go
index 008806f8b609a..165199aa37563 100644
--- a/pkg/fileservice/qcloud_sdk.go
+++ b/pkg/fileservice/qcloud_sdk.go
@@ -25,7 +25,9 @@ import (
"net/url"
"os"
gotrace "runtime/trace"
+ "sort"
"strconv"
+ "sync"
"time"
"github.com/matrixorigin/matrixone/pkg/common/moerr"
@@ -125,6 +127,7 @@ func NewQCloudSDK(
}
var _ ObjectStorage = new(QCloudSDK)
+var _ ParallelMultipartWriter = new(QCloudSDK)
func (a *QCloudSDK) List(
ctx context.Context,
@@ -280,6 +283,269 @@ func (a *QCloudSDK) Write(
return
}
+func (a *QCloudSDK) SupportsParallelMultipart() bool {
+ return true
+}
+
+func (a *QCloudSDK) WriteMultipartParallel(
+ ctx context.Context,
+ key string,
+ r io.Reader,
+ sizeHint *int64,
+ opt *ParallelMultipartOption,
+) (err error) {
+ defer wrapSizeMismatchErr(&err)
+
+ options := normalizeParallelOption(opt)
+ if sizeHint != nil && *sizeHint < minMultipartPartSize {
+ return a.Write(ctx, key, r, sizeHint, options.Expire)
+ }
+ if sizeHint != nil {
+ expectedParts := (*sizeHint + options.PartSize - 1) / options.PartSize
+ if expectedParts > maxMultipartParts {
+ return moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", expectedParts)
+ }
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ bufPool := sync.Pool{
+ New: func() any {
+ buf := make([]byte, options.PartSize)
+ return &buf
+ },
+ }
+
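+ // readChunk fills a pooled buffer with up to PartSize bytes. A short final read
+ // (io.ErrUnexpectedEOF) is normalized to io.EOF so the caller treats the last partial
+ // part uniformly; the buffer goes back to the pool when nothing was read or the read failed.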
+ readChunk := func() (bufPtr *[]byte, buf []byte, n int, err error) {
+ bufPtr = bufPool.Get().(*[]byte)
+ raw := *bufPtr
+ n, err = io.ReadFull(r, raw)
+ switch {
+ case errors.Is(err, io.EOF):
+ bufPool.Put(bufPtr)
+ return nil, nil, 0, io.EOF
+ case errors.Is(err, io.ErrUnexpectedEOF):
+ err = io.EOF
+ return bufPtr, raw, n, err
+ case err != nil:
+ bufPool.Put(bufPtr)
+ return nil, nil, 0, err
+ default:
+ return bufPtr, raw, n, nil
+ }
+ }
+
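+ // The first chunk decides the strategy: an empty reader is a no-op, a single chunk
+ // smaller than the multipart minimum falls back to a plain PUT, and anything larger
+ // initiates a multipart upload.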
+ firstBufPtr, firstBuf, firstN, err := readChunk()
+ if err != nil && !errors.Is(err, io.EOF) {
+ return err
+ }
+ if firstN == 0 && errors.Is(err, io.EOF) {
+ return nil
+ }
+ if errors.Is(err, io.EOF) && int64(firstN) < minMultipartPartSize {
+ data := make([]byte, firstN)
+ copy(data, firstBuf[:firstN])
+ bufPool.Put(firstBufPtr)
+ size := int64(firstN)
+ return a.Write(ctx, key, bytes.NewReader(data), &size, options.Expire)
+ }
+
+ var expiresHeader string
+ if options.Expire != nil {
+ expiresHeader = options.Expire.UTC().Format(http.TimeFormat)
+ }
+
+ initOpt := &cos.InitiateMultipartUploadOptions{
+ ObjectPutHeaderOptions: &cos.ObjectPutHeaderOptions{
+ Expires: expiresHeader,
+ },
+ }
+ output, createErr := DoWithRetry("cos initiate multipart upload", func() (*cos.InitiateMultipartUploadResult, error) {
+ res, _, e := a.client.Object.InitiateMultipartUpload(ctx, key, initOpt)
+ return res, e
+ }, maxRetryAttemps, IsRetryableError)
+ if createErr != nil {
+ bufPool.Put(firstBufPtr)
+ return createErr
+ }
+
+ defer func() {
+ if err != nil {
+ _, _ = a.client.Object.AbortMultipartUpload(ctx, key, output.UploadID)
+ }
+ }()
+
+ type partJob struct {
+ num int32
+ buf []byte
+ bufPtr *[]byte
+ n int
+ }
+
+ var (
+ partNum int32
+ parts []cos.Object
+ partsLock sync.Mutex
+ wg sync.WaitGroup
+ errOnce sync.Once
+ firstErr error
+ )
+
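+ // setErr records only the first error and cancels the context so in-flight workers
+ // and the reader loop stop early.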
+ setErr := func(e error) {
+ if e == nil {
+ return
+ }
+ errOnce.Do(func() {
+ firstErr = e
+ cancel()
+ })
+ }
+
+ jobCh := make(chan partJob, options.Concurrency*2)
+
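+ // Each worker drains jobCh, uploads its part with retry, records the returned ETag,
+ // and returns the buffer to the pool; jobs seen after cancellation are dropped.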
+ startWorker := func() error {
+ wg.Add(1)
+ return getParallelUploadPool().Submit(func() {
+ defer wg.Done()
+ for job := range jobCh {
+ if ctx.Err() != nil {
+ if job.bufPtr != nil {
+ bufPool.Put(job.bufPtr)
+ }
+ continue
+ }
+ uploadOpt := &cos.ObjectUploadPartOptions{
+ ContentLength: int64(job.n),
+ }
+ resp, uploadErr := DoWithRetry("cos upload part", func() (*cos.Response, error) {
+ return a.client.Object.UploadPart(ctx, key, output.UploadID, int(job.num), bytes.NewReader(job.buf[:job.n]), uploadOpt)
+ }, maxRetryAttemps, IsRetryableError)
+ if uploadErr != nil {
+ setErr(uploadErr)
+ if job.bufPtr != nil {
+ bufPool.Put(job.bufPtr)
+ }
+ continue
+ }
+ etag := ""
+ if resp != nil && resp.Header != nil {
+ etag = resp.Header.Get("ETag")
+ }
+ if job.bufPtr != nil {
+ bufPool.Put(job.bufPtr)
+ }
+ partsLock.Lock()
+ parts = append(parts, cos.Object{
+ PartNumber: int(job.num),
+ ETag: etag,
+ })
+ partsLock.Unlock()
+ }
+ })
+ }
+
+ for i := 0; i < options.Concurrency; i++ {
+ if submitErr := startWorker(); submitErr != nil {
+ setErr(submitErr)
+ break
+ }
+ }
+
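+ // sendJob assigns the next part number and hands the buffer to a worker; it returns
+ // false and reclaims the buffer when the part limit is exceeded or the context is canceled.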
+ sendJob := func(bufPtr *[]byte, buf []byte, n int) bool {
+ partNum++
+ if partNum > maxMultipartParts {
+ setErr(moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", partNum))
+ if bufPtr != nil {
+ bufPool.Put(bufPtr)
+ }
+ return false
+ }
+ job := partJob{
+ num: partNum,
+ buf: buf,
+ bufPtr: bufPtr,
+ n: n,
+ }
+ select {
+ case jobCh <- job:
+ return true
+ case <-ctx.Done():
+ if bufPtr != nil {
+ bufPool.Put(bufPtr)
+ }
+ setErr(ctx.Err())
+ return false
+ }
+ }
+
+ if !sendJob(firstBufPtr, firstBuf, firstN) {
+ close(jobCh)
+ wg.Wait()
+ if firstErr != nil {
+ return firstErr
+ }
+ return ctx.Err()
+ }
+
+ for {
+ nextBufPtr, nextBuf, nextN, readErr := readChunk()
+ if errors.Is(readErr, io.EOF) && nextN == 0 {
+ break
+ }
+ if readErr != nil && !errors.Is(readErr, io.EOF) {
+ setErr(readErr)
+ if nextBufPtr != nil {
+ bufPool.Put(nextBufPtr)
+ }
+ break
+ }
+ if nextN == 0 {
+ if nextBufPtr != nil {
+ bufPool.Put(nextBufPtr)
+ }
+ break
+ }
+ if !sendJob(nextBufPtr, nextBuf, nextN) {
+ break
+ }
+ if readErr != nil && errors.Is(readErr, io.EOF) {
+ break
+ }
+ }
+
+ close(jobCh)
+ wg.Wait()
+
+ if firstErr != nil {
+ err = firstErr
+ return err
+ }
+ if len(parts) == 0 {
+ return nil
+ }
+ if len(parts) != int(partNum) {
+ return moerr.NewInternalErrorNoCtxf("multipart upload incomplete, expect %d parts got %d", partNum, len(parts))
+ }
+
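+ // Workers complete out of order, so sort the part list by part number before completing the upload.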
+ sort.Slice(parts, func(i, j int) bool {
+ return parts[i].PartNumber < parts[j].PartNumber
+ })
+
+ completeOpt := &cos.CompleteMultipartUploadOptions{
+ Parts: parts,
+ }
+ _, err = DoWithRetry("cos complete multipart upload", func() (*cos.CompleteMultipartUploadResult, error) {
+ res, _, e := a.client.Object.CompleteMultipartUpload(ctx, key, output.UploadID, completeOpt)
+ return res, e
+ }, maxRetryAttemps, IsRetryableError)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func (a *QCloudSDK) Read(
ctx context.Context,
key string,
diff --git a/pkg/fileservice/s3_fs.go b/pkg/fileservice/s3_fs.go
index 7a2a58570a838..c1d562b87d77e 100644
--- a/pkg/fileservice/s3_fs.go
+++ b/pkg/fileservice/s3_fs.go
@@ -52,6 +52,8 @@ type S3FS struct {
perfCounterSets []*perfcounter.CounterSet
ioMerger *IOMerger
+
+ parallelMode ParallelMode
}
// key mapping scheme:
@@ -76,6 +78,7 @@ func NewS3FS(
asyncUpdate: true,
perfCounterSets: perfCounterSets,
ioMerger: NewIOMerger(),
+ parallelMode: args.ParallelMode,
}
var err error
@@ -452,8 +455,30 @@ func (s *S3FS) write(ctx context.Context, vector IOVector) (bytesWritten int, er
expire = &vector.ExpireAt
}
key := s.pathToKey(path.File)
- if err := s.storage.Write(ctx, key, reader, size, expire); err != nil {
- return 0, err
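+ // Decide the write path: ParallelForce always opts in, ParallelAuto opts in for unknown or
+ // large sizes; the parallel path is used only when the backing storage supports it,
+ // otherwise fall back to a single Write.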
+ enableParallel := false
+ switch s.parallelMode {
+ case ParallelForce:
+ enableParallel = true
+ case ParallelAuto:
+ if size == nil || *size >= minMultipartPartSize {
+ enableParallel = true
+ }
+ }
+
+ if pmw, ok := s.storage.(ParallelMultipartWriter); ok && pmw.SupportsParallelMultipart() &&
+ enableParallel {
+ opt := &ParallelMultipartOption{
+ PartSize: defaultParallelMultipartPartSize,
+ Concurrency: runtime.NumCPU(),
+ Expire: expire,
+ }
+ if err := pmw.WriteMultipartParallel(ctx, key, reader, size, opt); err != nil {
+ return 0, err
+ }
+ } else {
+ if err := s.storage.Write(ctx, key, reader, size, expire); err != nil {
+ return 0, err
+ }
}
// write to disk cache
diff --git a/pkg/fileservice/s3_fs_test.go b/pkg/fileservice/s3_fs_test.go
index 042bc27369395..c01189cec2599 100644
--- a/pkg/fileservice/s3_fs_test.go
+++ b/pkg/fileservice/s3_fs_test.go
@@ -226,6 +226,69 @@ func TestDynamicS3Opts(t *testing.T) {
})
}
+func TestDynamicS3NoKey(t *testing.T) {
+ ctx := WithParallelMode(context.Background(), ParallelForce)
+ prefix := "etl-prefix"
+ buf := new(strings.Builder)
+ w := csv.NewWriter(buf)
+ err := w.Write([]string{
+ "s3-no-key",
+ "disk",
+ "",
+ t.TempDir(),
+ prefix,
+ "s3-nokey",
+ })
+ assert.Nil(t, err)
+ w.Flush()
+
+ fs, path, err := GetForETL(ctx, nil, JoinPath(buf.String(), "foo.txt"))
+ assert.Nil(t, err)
+ assert.Equal(t, "foo.txt", path)
+ s3fs, ok := fs.(*S3FS)
+ assert.True(t, ok)
+ assert.Equal(t, ParallelForce, s3fs.parallelMode)
+ assert.Equal(t, prefix, s3fs.keyPrefix)
+ assert.IsType(t, &diskObjectStorage{}, baseStorage(t, s3fs))
+}
+
+func TestDynamicMinio(t *testing.T) {
+ ctx := WithParallelMode(context.Background(), ParallelForce)
+ buf := new(strings.Builder)
+ w := csv.NewWriter(buf)
+ err := w.Write([]string{
+ "minio",
+ "http://127.0.0.1:9000",
+ "us-east-1",
+ "bucket",
+ "ak",
+ "sk",
+ "pref",
+ "minio-etl",
+ })
+ assert.Nil(t, err)
+ w.Flush()
+
+ fs, path, err := GetForETL(ctx, nil, JoinPath(buf.String(), "bar.txt"))
+ assert.Nil(t, err)
+ assert.Equal(t, "bar.txt", path)
+ s3fs, ok := fs.(*S3FS)
+ assert.True(t, ok)
+ assert.Equal(t, ParallelForce, s3fs.parallelMode)
+ assert.Equal(t, "pref", s3fs.keyPrefix)
+ assert.IsType(t, &MinioSDK{}, baseStorage(t, s3fs))
+}
+
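+// baseStorage unwraps the HTTP-trace, metrics, and semaphore decorators to reach the underlying ObjectStorage.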
+func baseStorage(t *testing.T, fs *S3FS) ObjectStorage {
+ trace, ok := fs.storage.(*objectStorageHTTPTrace)
+ assert.True(t, ok)
+ metrics, ok := trace.upstream.(*objectStorageMetrics)
+ assert.True(t, ok)
+ sem, ok := metrics.upstream.(*objectStorageSemaphore)
+ assert.True(t, ok)
+ return sem.upstream
+}
+
func BenchmarkS3FS(b *testing.B) {
cacheDir := b.TempDir()
b.ResetTimer()
diff --git a/pkg/frontend/data_branch.go b/pkg/frontend/data_branch.go
index 99eb2519c4efb..88df513ac5341 100644
--- a/pkg/frontend/data_branch.go
+++ b/pkg/frontend/data_branch.go
@@ -144,6 +144,42 @@ type batchWithKind struct {
batch *batch.Batch
}
+type emitFunc func(batchWithKind) (stop bool, err error)
+
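+// newEmitter returns an emitFunc that delivers a batch to retCh. The first select gives
+// cancellation and stop priority over an immediately ready send; the second blocks until
+// the batch is delivered or the pipeline shuts down.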
+func newEmitter(
+ ctx context.Context, stopCh <-chan struct{}, retCh chan batchWithKind,
+) emitFunc {
+ return func(wrapped batchWithKind) (bool, error) {
+ select {
+ case <-ctx.Done():
+ return false, ctx.Err()
+ case <-stopCh:
+ return true, nil
+ default:
+ }
+
+ select {
+ case <-ctx.Done():
+ return false, ctx.Err()
+ case <-stopCh:
+ return true, nil
+ case retCh <- wrapped:
+ return false, nil
+ }
+ }
+}
+
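+// emitBatch wraps emit and releases the batch back to the pool whenever the consumer has
+// stopped or an error occurred, so producers never leak pooled batches.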
+func emitBatch(
+ emit emitFunc, wrapped batchWithKind, forTombstone bool, pool *retBatchList,
+) (bool, error) {
+ stop, err := emit(wrapped)
+ if stop || err != nil {
+ pool.releaseRetBatch(wrapped.batch, forTombstone)
+ return stop, err
+ }
+ return false, nil
+}
+
type retBatchList struct {
mu sync.Mutex
// 0: data
@@ -633,11 +669,17 @@ func diffMergeAgency(
}()
var (
- dagInfo branchMetaInfo
- tblStuff tableStuff
- copt compositeOption
- ctx, cancel = context.WithCancel(execCtx.reqCtx)
+ ctx context.Context
+ cancel context.CancelFunc
+ )
+ // fileservice.WithParallelMode(execCtx.reqCtx, fileservice.ParallelForce) can be applied here to force parallel multipart uploads for this request
+ ctx, cancel = context.WithCancel(execCtx.reqCtx)
+
+ var (
+ dagInfo branchMetaInfo
+ tblStuff tableStuff
+ copt compositeOption
ok bool
diffStmt *tree.DataBranchDiff
mergeStmt *tree.DataBranchMerge
@@ -683,7 +725,11 @@ func diffMergeAgency(
done bool
wg = new(sync.WaitGroup)
outputErr atomic.Value
- retBatCh = make(chan batchWithKind, 100)
+ retBatCh = make(chan batchWithKind, 10)
+ stopCh = make(chan struct{})
+ stopOnce sync.Once
+ emit emitFunc
+ stop func()
waited bool
)
@@ -699,6 +745,13 @@ func diffMergeAgency(
}
}()
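+ // stop closes stopCh exactly once; emitters then report stop=true so producers release
+ // their pooled batches instead of blocking on a full channel.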
+ emit = newEmitter(ctx, stopCh, retBatCh)
+ stop = func() {
+ stopOnce.Do(func() {
+ close(stopCh)
+ })
+ }
+
if diffStmt != nil {
if err = buildOutputSchema(ctx, ses, diffStmt, tblStuff); err != nil {
return
@@ -725,7 +778,7 @@ func diffMergeAgency(
// 5. as file
if err2 := satisfyDiffOutputOpt(
- ctx, cancel, ses, bh, diffStmt, dagInfo, tblStuff, retBatCh,
+ ctx, cancel, stop, ses, bh, diffStmt, dagInfo, tblStuff, retBatCh,
); err2 != nil {
outputErr.Store(err2)
}
@@ -739,7 +792,7 @@ func diffMergeAgency(
}()
if err = diffOnBase(
- ctx, ses, bh, wg, dagInfo, tblStuff, retBatCh, copt,
+ ctx, ses, bh, wg, dagInfo, tblStuff, copt, emit,
); err != nil {
return
}
@@ -909,15 +962,19 @@ func writeReplaceRowValues(
tblStuff tableStuff,
row []any,
buf *bytes.Buffer,
-) {
+) error {
buf.WriteString("(")
for i, idx := range tblStuff.def.visibleIdxes {
- formatValIntoString(ses, row[idx], tblStuff.def.colTypes[idx], buf)
+ if err := formatValIntoString(ses, row[idx], tblStuff.def.colTypes[idx], buf); err != nil {
+ return err
+ }
if i != len(tblStuff.def.visibleIdxes)-1 {
buf.WriteString(",")
}
}
buf.WriteString(")")
+
+ return nil
}
func writeDeleteRowValues(
@@ -925,12 +982,14 @@ func writeDeleteRowValues(
tblStuff tableStuff,
row []any,
buf *bytes.Buffer,
-) {
+) error {
if len(tblStuff.def.pkColIdxes) > 1 {
buf.WriteString("(")
}
for i, colIdx := range tblStuff.def.pkColIdxes {
- formatValIntoString(ses, row[colIdx], tblStuff.def.colTypes[colIdx], buf)
+ if err := formatValIntoString(ses, row[colIdx], tblStuff.def.colTypes[colIdx], buf); err != nil {
+ return err
+ }
if i != len(tblStuff.def.pkColIdxes)-1 {
buf.WriteString(",")
}
@@ -938,6 +997,8 @@ func writeDeleteRowValues(
if len(tblStuff.def.pkColIdxes) > 1 {
buf.WriteString(")")
}
+
+ return nil
}
func appendBatchRowsAsSQLValues(
@@ -1005,9 +1066,13 @@ func appendBatchRowsAsSQLValues(
tmpValsBuffer.Reset()
if wrapped.kind == diffDelete {
- writeDeleteRowValues(ses, tblStuff, row, tmpValsBuffer)
+ if err = writeDeleteRowValues(ses, tblStuff, row, tmpValsBuffer); err != nil {
+ return
+ }
} else {
- writeReplaceRowValues(ses, tblStuff, row, tmpValsBuffer)
+ if err = writeReplaceRowValues(ses, tblStuff, row, tmpValsBuffer); err != nil {
+ return
+ }
}
if tmpValsBuffer.Len() == 0 {
@@ -1104,6 +1169,7 @@ func mergeDiffs(
func satisfyDiffOutputOpt(
ctx context.Context,
cancel context.CancelFunc,
+ stop func(),
ses *Session,
bh BackgroundExec,
stmt *tree.DataBranchDiff,
@@ -1161,10 +1227,10 @@ func satisfyDiffOutputOpt(
rows = append(rows, row)
if stmt.OutputOpt != nil && stmt.OutputOpt.Limit != nil &&
- int64(mrs.GetRowCount()) >= *stmt.OutputOpt.Limit {
+ int64(len(rows)) >= *stmt.OutputOpt.Limit {
// hit limit, cancel producers but keep draining the channel
hitLimit = true
- cancel()
+ stop()
break
}
}
@@ -2049,6 +2115,12 @@ func writeCSV(
case <-inputCtx.Done():
err = errors.Join(err, inputCtx.Err())
stop = true
+ case e := <-writerErr:
+ if e != nil {
+ err = errors.Join(err, e)
+ }
+ stop = true
+ cancelCtx()
case e, ok := <-errChan:
if !ok {
errOpen = false
@@ -2217,8 +2289,8 @@ func diffOnBase(
wg *sync.WaitGroup,
dagInfo branchMetaInfo,
tblStuff tableStuff,
- retCh chan batchWithKind,
copt compositeOption,
+ emit emitFunc,
) (err error) {
defer func() {
@@ -2278,8 +2350,8 @@ func diffOnBase(
}
if err = hashDiff(
- ctx, ses, bh, tblStuff, dagInfo, retCh,
- copt, tarHandle, baseHandle,
+ ctx, ses, bh, tblStuff, dagInfo,
+ copt, emit, tarHandle, baseHandle,
); err != nil {
return
}
@@ -2407,8 +2479,8 @@ func hashDiff(
bh BackgroundExec,
tblStuff tableStuff,
dagInfo branchMetaInfo,
- retCh chan batchWithKind,
copt compositeOption,
+ emit emitFunc,
tarHandle []engine.ChangesHandle,
baseHandle []engine.ChangesHandle,
) (
@@ -2452,7 +2524,7 @@ func hashDiff(
if dagInfo.lcaType == lcaEmpty {
if err = hashDiffIfNoLCA(
- ctx, ses, tblStuff, retCh, copt,
+ ctx, ses, tblStuff, copt, emit,
tarDataHashmap, tarTombstoneHashmap,
baseDataHashmap, baseTombstoneHashmap,
); err != nil {
@@ -2460,7 +2532,7 @@ func hashDiff(
}
} else {
if err = hashDiffIfHasLCA(
- ctx, ses, bh, dagInfo, tblStuff, retCh, copt,
+ ctx, ses, bh, dagInfo, tblStuff, copt, emit,
tarDataHashmap, tarTombstoneHashmap,
baseDataHashmap, baseTombstoneHashmap,
); err != nil {
@@ -2477,8 +2549,8 @@ func hashDiffIfHasLCA(
bh BackgroundExec,
dagInfo branchMetaInfo,
tblStuff tableStuff,
- retCh chan batchWithKind,
copt compositeOption,
+ emit emitFunc,
tarDataHashmap databranchutils.BranchHashmap,
tarTombstoneHashmap databranchutils.BranchHashmap,
baseDataHashmap databranchutils.BranchHashmap,
@@ -2511,7 +2583,11 @@ func hashDiffIfHasLCA(
handleTarDeleteAndUpdates := func(wrapped batchWithKind) (err2 error) {
if len(baseUpdateBatches) == 0 && len(baseDeleteBatches) == 0 {
// no need to check conflict
- retCh <- wrapped
+ if stop, e := emitBatch(emit, wrapped, false, tblStuff.retPool); e != nil {
+ return e
+ } else if stop {
+ return nil
+ }
return nil
}
@@ -2576,7 +2652,12 @@ func hashDiffIfHasLCA(
}
buf := acquireBuffer(tblStuff.bufPool)
- formatValIntoString(ses, tarRow[0], tblStuff.def.colTypes[tblStuff.def.pkColIdx], buf)
+ if err3 = formatValIntoString(
+ ses, tarRow[0], tblStuff.def.colTypes[tblStuff.def.pkColIdx], buf,
+ ); err3 != nil {
+ releaseBuffer(tblStuff.bufPool, buf)
+ return
+ }
err3 = moerr.NewInternalErrorNoCtxf(
"conflict: %s %s and %s %s on pk(%v) with different values",
@@ -2653,7 +2734,13 @@ func hashDiffIfHasLCA(
return
}
- retCh <- wrapped
+ stop, e := emitBatch(emit, wrapped, false, tblStuff.retPool)
+ if e != nil {
+ return e
+ }
+ if stop {
+ return nil
+ }
return
}
@@ -2725,12 +2812,16 @@ func hashDiffIfHasLCA(
} else {
err2 = handleTarDeleteAndUpdates(wrapped)
}
+
+ if errors.Is(err2, context.Canceled) {
+ err2 = nil
+ cancel()
+ }
}
if err2 != nil {
first = err2
cancel()
- tblStuff.retPool.releaseRetBatch(wrapped.batch, false)
}
}
@@ -2758,24 +2849,49 @@ func hashDiffIfHasLCA(
// what can I do with these left base updates/inserts ?
if copt.conflictOpt == nil {
- for _, w := range baseUpdateBatches {
- retCh <- w
+ stopped := false
+ for i, w := range baseUpdateBatches {
+ var stop bool
+ if stop, err = emitBatch(emit, w, false, tblStuff.retPool); err != nil {
+ return err
+ }
+ if stop {
+ stopped = true
+ for j := i + 1; j < len(baseUpdateBatches); j++ {
+ tblStuff.retPool.releaseRetBatch(baseUpdateBatches[j].batch, false)
+ }
+ for _, bw := range baseDeleteBatches {
+ tblStuff.retPool.releaseRetBatch(bw.batch, false)
+ }
+ break
+ }
}
- for _, w := range baseDeleteBatches {
- retCh <- w
+ if !stopped {
+ for i, w := range baseDeleteBatches {
+ var stop bool
+ if stop, err = emitBatch(emit, w, false, tblStuff.retPool); err != nil {
+ return err
+ }
+ if stop {
+ for j := i + 1; j < len(baseDeleteBatches); j++ {
+ tblStuff.retPool.releaseRetBatch(baseDeleteBatches[j].batch, false)
+ }
+ break
+ }
+ }
}
}
- return diffDataHelper(ctx, ses, copt, tblStuff, retCh, tarDataHashmap, baseDataHashmap)
+ return diffDataHelper(ctx, ses, copt, tblStuff, emit, tarDataHashmap, baseDataHashmap)
}
func hashDiffIfNoLCA(
ctx context.Context,
ses *Session,
tblStuff tableStuff,
- retCh chan batchWithKind,
copt compositeOption,
+ emit emitFunc,
tarDataHashmap databranchutils.BranchHashmap,
tarTombstoneHashmap databranchutils.BranchHashmap,
baseDataHashmap databranchutils.BranchHashmap,
@@ -2802,7 +2918,7 @@ func hashDiffIfNoLCA(
return
}
- return diffDataHelper(ctx, ses, copt, tblStuff, retCh, tarDataHashmap, baseDataHashmap)
+ return diffDataHelper(ctx, ses, copt, tblStuff, emit, tarDataHashmap, baseDataHashmap)
}
func compareRowInWrappedBatches(
@@ -3005,7 +3121,10 @@ func checkConflictAndAppendToBat(
case tree.CONFLICT_FAIL:
buf := acquireBuffer(tblStuff.bufPool)
for i, idx := range tblStuff.def.pkColIdxes {
- formatValIntoString(ses, tarTuple[idx], tblStuff.def.colTypes[idx], buf)
+ if err2 = formatValIntoString(ses, tarTuple[idx], tblStuff.def.colTypes[idx], buf); err2 != nil {
+ releaseBuffer(tblStuff.bufPool, buf)
+ return err2
+ }
if i < len(tblStuff.def.pkColIdxes)-1 {
buf.WriteString(",")
}
@@ -3041,7 +3160,7 @@ func diffDataHelper(
ses *Session,
copt compositeOption,
tblStuff tableStuff,
- retCh chan batchWithKind,
+ emit emitFunc,
tarDataHashmap databranchutils.BranchHashmap,
baseDataHashmap databranchutils.BranchHashmap,
) (err error) {
@@ -3156,16 +3275,24 @@ func diffDataHelper(
default:
}
- retCh <- batchWithKind{
+ if stop, err3 := emitBatch(emit, batchWithKind{
batch: tarBat,
kind: diffInsert,
name: tblStuff.tarRel.GetTableName(),
+ }, false, tblStuff.retPool); err3 != nil {
+ return err3
+ } else if stop {
+ return nil
}
- retCh <- batchWithKind{
+ if stop, err3 := emitBatch(emit, batchWithKind{
batch: baseBat,
kind: diffInsert,
name: tblStuff.baseRel.GetTableName(),
+ }, false, tblStuff.retPool); err3 != nil {
+ return err3
+ } else if stop {
+ return nil
}
return nil
@@ -3214,10 +3341,14 @@ func diffDataHelper(
default:
}
- retCh <- batchWithKind{
+ if stop, err3 := emitBatch(emit, batchWithKind{
batch: bat,
kind: diffInsert,
name: tblStuff.baseRel.GetTableName(),
+ }, false, tblStuff.retPool); err3 != nil {
+ return err3
+ } else if stop {
+ return nil
}
return nil
}, -1); err != nil {
@@ -3313,7 +3444,11 @@ func handleDelsOnLCA(
valsBuf.WriteString(fmt.Sprintf("row(%d,", i))
for j := range tuple {
- formatValIntoString(ses, tuple[j], colTypes[expandedPKColIdxes[j]], valsBuf)
+ if err = formatValIntoString(
+ ses, tuple[j], colTypes[expandedPKColIdxes[j]], valsBuf,
+ ); err != nil {
+ return nil, err
+ }
if j != len(tuple)-1 {
valsBuf.WriteString(", ")
}
@@ -3461,10 +3596,10 @@ func handleDelsOnLCA(
return
}
-func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) {
+func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) error {
if val == nil {
buf.WriteString("NULL")
- return
+ return nil
}
var scratch [64]byte
@@ -3481,6 +3616,10 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
buf.Write(strconv.AppendFloat(scratch[:0], v, 'g', -1, bitSize))
}
+ writeBool := func(v bool) {
+ buf.WriteString(strconv.FormatBool(v))
+ }
+
switch t.Oid {
case types.T_varchar, types.T_text, types.T_json, types.T_char, types.
T_varbinary, types.T_binary:
@@ -3491,7 +3630,7 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
strVal = x.String()
case *bytejson.ByteJson:
if x == nil {
- panic(moerr.NewInternalErrorNoCtx("formatValIntoString: nil *bytejson.ByteJson"))
+ return moerr.NewInternalErrorNoCtx("formatValIntoString: nil *bytejson.ByteJson")
}
strVal = x.String()
case []byte:
@@ -3499,14 +3638,14 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
case string:
strVal = x
default:
- panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected json type %T", val))
+ return moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected json type %T", val)
}
jsonLiteral := escapeJSONControlBytes([]byte(strVal))
if !json.Valid(jsonLiteral) {
- panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: invalid json input %q", strVal))
+ return moerr.NewInternalErrorNoCtxf("formatValIntoString: invalid json input %q", strVal)
}
writeEscapedSQLString(buf, jsonLiteral)
- return
+ return nil
}
switch x := val.(type) {
case []byte:
@@ -3514,7 +3653,7 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
case string:
writeEscapedSQLString(buf, []byte(x))
default:
- panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected string type %T", val))
+ return moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected string type %T", val)
}
case types.T_timestamp:
buf.WriteString("'")
@@ -3524,6 +3663,10 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
buf.WriteString("'")
buf.WriteString(val.(types.Datetime).String2(t.Scale))
buf.WriteString("'")
+ case types.T_time:
+ buf.WriteString("'")
+ buf.WriteString(val.(types.Time).String2(t.Scale))
+ buf.WriteString("'")
case types.T_date:
buf.WriteString("'")
buf.WriteString(val.(types.Date).String())
@@ -3534,6 +3677,8 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
buf.WriteString(val.(types.Decimal128).Format(t.Scale))
case types.T_decimal256:
buf.WriteString(val.(types.Decimal256).Format(t.Scale))
+ case types.T_bool:
+ writeBool(val.(bool))
case types.T_uint8:
writeUint(uint64(val.(uint8)))
case types.T_uint16:
@@ -3563,8 +3708,10 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer)
buf.WriteString(types.ArrayToString[float64](val.([]float64)))
buf.WriteString("'")
default:
- panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: unsupported type %v", t.Oid))
+ return moerr.NewNotSupportedNoCtxf("formatValIntoString: not support type %v", t.Oid)
}
+
+ return nil
}
// writeEscapedSQLString escapes special and control characters for SQL literal output.
@@ -4389,114 +4536,140 @@ func compareSingleValInVector(
// Use raw values to avoid format conversions in extractRowFromVector.
switch vec1.GetType().Oid {
case types.T_json:
- val1 := types.DecodeJson(vec1.GetBytesAt(rowIdx1))
- val2 := types.DecodeJson(vec2.GetBytesAt(rowIdx2))
- return bytejson.CompareByteJson(val1, val2), nil
+ return bytejson.CompareByteJson(
+ types.DecodeJson(vec1.GetBytesAt(rowIdx1)),
+ types.DecodeJson(vec2.GetBytesAt(rowIdx2)),
+ ), nil
case types.T_bool:
- val1 := vector.GetFixedAtNoTypeCheck[bool](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[bool](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[bool](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[bool](vec2, rowIdx2),
+ ), nil
case types.T_bit:
- val1 := vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2),
+ ), nil
case types.T_int8:
- val1 := vector.GetFixedAtNoTypeCheck[int8](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[int8](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[int8](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[int8](vec2, rowIdx2),
+ ), nil
case types.T_uint8:
- val1 := vector.GetFixedAtNoTypeCheck[uint8](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[uint8](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[uint8](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[uint8](vec2, rowIdx2),
+ ), nil
case types.T_int16:
- val1 := vector.GetFixedAtNoTypeCheck[int16](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[int16](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[int16](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[int16](vec2, rowIdx2),
+ ), nil
case types.T_uint16:
- val1 := vector.GetFixedAtNoTypeCheck[uint16](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[uint16](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[uint16](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[uint16](vec2, rowIdx2),
+ ), nil
case types.T_int32:
- val1 := vector.GetFixedAtNoTypeCheck[int32](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[int32](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[int32](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[int32](vec2, rowIdx2),
+ ), nil
case types.T_uint32:
- val1 := vector.GetFixedAtNoTypeCheck[uint32](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[uint32](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[uint32](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[uint32](vec2, rowIdx2),
+ ), nil
case types.T_int64:
- val1 := vector.GetFixedAtNoTypeCheck[int64](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[int64](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[int64](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[int64](vec2, rowIdx2),
+ ), nil
case types.T_uint64:
- val1 := vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2),
+ ), nil
case types.T_float32:
- val1 := vector.GetFixedAtNoTypeCheck[float32](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[float32](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[float32](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[float32](vec2, rowIdx2),
+ ), nil
case types.T_float64:
- val1 := vector.GetFixedAtNoTypeCheck[float64](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[float64](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[float64](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[float64](vec2, rowIdx2),
+ ), nil
case types.T_char, types.T_varchar, types.T_blob, types.T_text, types.T_binary, types.T_varbinary, types.T_datalink:
return bytes.Compare(
vec1.GetBytesAt(rowIdx1),
vec2.GetBytesAt(rowIdx2),
), nil
case types.T_array_float32:
- val1 := vector.GetArrayAt[float32](vec1, rowIdx1)
- val2 := vector.GetArrayAt[float32](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetArrayAt[float32](vec1, rowIdx1),
+ vector.GetArrayAt[float32](vec2, rowIdx2),
+ ), nil
case types.T_array_float64:
- val1 := vector.GetArrayAt[float64](vec1, rowIdx1)
- val2 := vector.GetArrayAt[float64](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetArrayAt[float64](vec1, rowIdx1),
+ vector.GetArrayAt[float64](vec2, rowIdx2),
+ ), nil
case types.T_date:
- val1 := vector.GetFixedAtNoTypeCheck[types.Date](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Date](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Date](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Date](vec2, rowIdx2),
+ ), nil
case types.T_datetime:
- val1 := vector.GetFixedAtNoTypeCheck[types.Datetime](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Datetime](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Datetime](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Datetime](vec2, rowIdx2),
+ ), nil
case types.T_time:
- val1 := vector.GetFixedAtNoTypeCheck[types.Time](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Time](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Time](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Time](vec2, rowIdx2),
+ ), nil
case types.T_timestamp:
- val1 := vector.GetFixedAtNoTypeCheck[types.Timestamp](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Timestamp](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Timestamp](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Timestamp](vec2, rowIdx2),
+ ), nil
case types.T_decimal64:
- val1 := vector.GetFixedAtNoTypeCheck[types.Decimal64](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Decimal64](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Decimal64](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Decimal64](vec2, rowIdx2),
+ ), nil
case types.T_decimal128:
- val1 := vector.GetFixedAtNoTypeCheck[types.Decimal128](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Decimal128](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Decimal128](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Decimal128](vec2, rowIdx2),
+ ), nil
case types.T_uuid:
- val1 := vector.GetFixedAtNoTypeCheck[types.Uuid](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Uuid](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Uuid](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Uuid](vec2, rowIdx2),
+ ), nil
case types.T_Rowid:
- val1 := vector.GetFixedAtNoTypeCheck[types.Rowid](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Rowid](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Rowid](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Rowid](vec2, rowIdx2),
+ ), nil
case types.T_Blockid:
- val1 := vector.GetFixedAtNoTypeCheck[types.Blockid](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Blockid](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Blockid](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Blockid](vec2, rowIdx2),
+ ), nil
case types.T_TS:
- val1 := vector.GetFixedAtNoTypeCheck[types.TS](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.TS](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.TS](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.TS](vec2, rowIdx2),
+ ), nil
case types.T_enum:
- val1 := vector.GetFixedAtNoTypeCheck[types.Enum](vec1, rowIdx1)
- val2 := vector.GetFixedAtNoTypeCheck[types.Enum](vec2, rowIdx2)
- return types.CompareValue(val1, val2), nil
+ return types.CompareValue(
+ vector.GetFixedAtNoTypeCheck[types.Enum](vec1, rowIdx1),
+ vector.GetFixedAtNoTypeCheck[types.Enum](vec2, rowIdx2),
+ ), nil
default:
return 0, moerr.NewInternalErrorNoCtxf("compareSingleValInVector : unsupported type %d", vec1.GetType().Oid)
}
diff --git a/pkg/frontend/data_branch_test.go b/pkg/frontend/data_branch_test.go
index 0d9aba8339b0f..c6be503e4c9af 100644
--- a/pkg/frontend/data_branch_test.go
+++ b/pkg/frontend/data_branch_test.go
@@ -40,7 +40,7 @@ func TestFormatValIntoString_StringEscaping(t *testing.T) {
ses := &Session{}
val := "a'b\"c\\\n\t\r\x1a\x00"
- formatValIntoString(ses, val, types.New(types.T_varchar, 0, 0), &buf)
+ require.NoError(t, formatValIntoString(ses, val, types.New(types.T_varchar, 0, 0), &buf))
require.Equal(t, `'a\'b"c\\\n\t\r\Z\0'`, buf.String())
}
@@ -49,16 +49,27 @@ func TestFormatValIntoString_ByteEscaping(t *testing.T) {
ses := &Session{}
val := []byte{'x', 0x00, '\\', 0x07, '\''}
- formatValIntoString(ses, val, types.New(types.T_varbinary, 0, 0), &buf)
+ require.NoError(t, formatValIntoString(ses, val, types.New(types.T_varbinary, 0, 0), &buf))
require.Equal(t, `'x\0\\\x07\''`, buf.String())
}
+func TestFormatValIntoString_Time(t *testing.T) {
+ var buf bytes.Buffer
+ ses := &Session{}
+
+ val, err := types.ParseTime("12:34:56.123456", 6)
+ require.NoError(t, err)
+
+ require.NoError(t, formatValIntoString(ses, val, types.New(types.T_time, 0, 6), &buf))
+ require.Equal(t, `'12:34:56.123456'`, buf.String())
+}
+
func TestFormatValIntoString_JSONEscaping(t *testing.T) {
var buf bytes.Buffer
ses := &Session{}
val := `{"k":"` + string([]byte{0x01, '\n'}) + `"}`
- formatValIntoString(ses, val, types.New(types.T_json, 0, 0), &buf)
+ require.NoError(t, formatValIntoString(ses, val, types.New(types.T_json, 0, 0), &buf))
require.Equal(t, `'{"k":"\\u0001\\u000a"}'`, buf.String())
}
@@ -69,7 +80,7 @@ func TestFormatValIntoString_JSONByteJson(t *testing.T) {
bj, err := types.ParseStringToByteJson(`{"a":1}`)
require.NoError(t, err)
- formatValIntoString(ses, bj, types.New(types.T_json, 0, 0), &buf)
+ require.NoError(t, formatValIntoString(ses, bj, types.New(types.T_json, 0, 0), &buf))
require.Equal(t, `'{"a": 1}'`, buf.String())
}
@@ -77,7 +88,7 @@ func TestFormatValIntoString_Nil(t *testing.T) {
var buf bytes.Buffer
ses := &Session{}
- formatValIntoString(ses, nil, types.New(types.T_varchar, 0, 0), &buf)
+ require.NoError(t, formatValIntoString(ses, nil, types.New(types.T_varchar, 0, 0), &buf))
require.Equal(t, "NULL", buf.String())
}
@@ -85,9 +96,9 @@ func TestFormatValIntoString_UnsupportedType(t *testing.T) {
var buf bytes.Buffer
ses := &Session{}
- require.Panics(t, func() {
- formatValIntoString(ses, true, types.New(types.T_bool, 0, 0), &buf)
- })
+ err := formatValIntoString(ses, true, types.New(types.T_Rowid, 0, 0), &buf)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not support type")
}
func TestCompareSingleValInVector_AllTypes(t *testing.T) {
diff --git a/pkg/frontend/export.go b/pkg/frontend/export.go
index 1a62b65cca1cf..e6a0316a23551 100644
--- a/pkg/frontend/export.go
+++ b/pkg/frontend/export.go
@@ -364,6 +364,21 @@ func constructByte(ctx context.Context, obj FeSession, bat *batch.Batch, index i
mp = ss.GetMemPool()
}
+ // respect cancellation to avoid blocking when downstream writer stops
+ if ctx.Err() != nil {
+ bat.Clean(mp)
+ return
+ }
+
+ sendByte := func(bb *BatchByte) bool {
+ select {
+ case ByteChan <- bb:
+ return true
+ case <-ctx.Done():
+ return false
+ }
+ }
+
symbol := ep.Symbol
closeby := ep.userConfig.Fields.EnclosedBy.Value
terminated := ep.userConfig.Fields.Terminated.Value
@@ -501,9 +516,10 @@ func constructByte(ctx context.Context, obj FeSession, bat *batch.Batch, index i
zap.Int("typeOid", int(vec.GetType().Oid)))
}
- ByteChan <- &BatchByte{
+ // stop early if downstream already failed
+ sendByte(&BatchByte{
err: moerr.NewInternalErrorf(ctx, "constructByte : unsupported type %d", vec.GetType().Oid),
- }
+ })
bat.Clean(mp)
return
}
@@ -516,10 +532,13 @@ func constructByte(ctx context.Context, obj FeSession, bat *batch.Batch, index i
copy(result, buffer.Bytes())
buffer = nil
- ByteChan <- &BatchByte{
+ if !sendByte(&BatchByte{
index: index,
writeByte: result,
err: nil,
+ }) {
+ bat.Clean(mp)
+ return
}
if ss != nil {
diff --git a/pkg/tests/dml/dml_test.go b/pkg/tests/dml/dml_test.go
index ecb29a789837a..49376bad692cf 100644
--- a/pkg/tests/dml/dml_test.go
+++ b/pkg/tests/dml/dml_test.go
@@ -152,6 +152,15 @@ func TestDataBranchDiffAsFile(t *testing.T) {
t.Log("csv diff covers rich data type payloads")
runCSVLoadRichTypes(t, ctx, sqlDB)
+ t.Log("diff output limit returns subset of full diff")
+ runDiffOutputLimitSubset(t, ctx, sqlDB)
+
+ t.Log("diff output limit without branch relationship returns subset of full diff")
+ runDiffOutputLimitNoBase(t, ctx, sqlDB)
+
+ t.Log("diff output limit with large base workload still returns subset of full diff")
+ runDiffOutputLimitLargeBase(t, ctx, sqlDB)
+
t.Log("diff output to stage and load via datalink")
runDiffOutputToStage(t, ctx, sqlDB)
})
@@ -532,6 +541,149 @@ insert into %s values
assertTablesEqual(t, ctx, db, dbName, target, base)
}
+func runDiffOutputLimitSubset(t *testing.T, parentCtx context.Context, db *sql.DB) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(parentCtx, time.Second*90)
+ defer cancel()
+
+ dbName := testutils.GetDatabaseName(t)
+ base := "limit_base"
+ branch := "limit_branch"
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("create database `%s`", dbName))
+ defer func() {
+ execSQLDB(t, ctx, db, "use mo_catalog")
+ execSQLDB(t, ctx, db, fmt.Sprintf("drop database if exists `%s`", dbName))
+ }()
+ execSQLDB(t, ctx, db, fmt.Sprintf("use `%s`", dbName))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (id int primary key, val int, note varchar(16))", base))
+ execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (1, 10, 'seed'), (2, 20, 'seed'), (3, 30, 'seed'), (4, 40, 'seed'), (5, 50, 'seed'), (6, 60, 'seed')", base))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("data branch create table %s from %s", branch, base))
+ execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (7, 70, 'inserted'), (8, 80, 'inserted')", branch))
+ execSQLDB(t, ctx, db, fmt.Sprintf("update %s set val = val + 100, note = 'updated' where id in (2,3)", branch))
+ execSQLDB(t, ctx, db, fmt.Sprintf("delete from %s where id in (4,5)", branch))
+
+ fullStmt := fmt.Sprintf("data branch diff %s against %s", branch, base)
+ fullRows := fetchDiffRowsAsStrings(t, ctx, db, fullStmt)
+ require.GreaterOrEqual(t, len(fullRows), 6)
+
+ limit := 3
+ limitStmt := fmt.Sprintf("data branch diff %s against %s output limit %d", branch, base, limit)
+ limitedRows := fetchDiffRowsAsStrings(t, ctx, db, limitStmt)
+
+ require.NotEmpty(t, limitedRows, "limited diff returned no rows")
+ require.LessOrEqual(t, len(limitedRows), limit, "limited diff returned too many rows")
+
+ fullSet := make(map[string]struct{}, len(fullRows))
+ for _, row := range fullRows {
+ fullSet[strings.Join(row, "||")] = struct{}{}
+ }
+ for _, row := range limitedRows {
+ _, ok := fullSet[strings.Join(row, "||")]
+ require.Truef(t, ok, "limited diff row not contained in full diff: %v", row)
+ }
+}
+
+func runDiffOutputLimitNoBase(t *testing.T, parentCtx context.Context, db *sql.DB) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(parentCtx, time.Second*90)
+ defer cancel()
+
+ dbName := testutils.GetDatabaseName(t)
+ base := "limit_nobranch_base"
+ target := "limit_nobranch_target"
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("create database `%s`", dbName))
+ defer func() {
+ execSQLDB(t, ctx, db, "use mo_catalog")
+ execSQLDB(t, ctx, db, fmt.Sprintf("drop database if exists `%s`", dbName))
+ }()
+ execSQLDB(t, ctx, db, fmt.Sprintf("use `%s`", dbName))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (id int primary key, val int, note varchar(16))", base))
+ execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (id int primary key, val int, note varchar(16))", target))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (1, 10, 'seed'), (2, 20, 'seed'), (3, 30, 'seed'), (4, 40, 'seed')", base))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (1, 110, 'updated'), (2, 20, 'seed'), (5, 500, 'added'), (6, 600, 'added')", target))
+
+ fullStmt := fmt.Sprintf("data branch diff %s against %s", target, base)
+ fullRows := fetchDiffRowsAsStrings(t, ctx, db, fullStmt)
+ require.GreaterOrEqual(t, len(fullRows), 3)
+
+ limit := 1
+ limitStmt := fmt.Sprintf("data branch diff %s against %s output limit %d", target, base, limit)
+ limitedRows := fetchDiffRowsAsStrings(t, ctx, db, limitStmt)
+
+ require.NotEmpty(t, limitedRows, "limited diff returned no rows")
+ require.LessOrEqual(t, len(limitedRows), limit, "limited diff returned too many rows")
+
+ fullSet := make(map[string]struct{}, len(fullRows))
+ for _, row := range fullRows {
+ fullSet[strings.Join(row, "||")] = struct{}{}
+ }
+ for _, row := range limitedRows {
+ _, ok := fullSet[strings.Join(row, "||")]
+ require.Truef(t, ok, "limited diff row not contained in full diff: %v", row)
+ }
+}
+
+func runDiffOutputLimitLargeBase(t *testing.T, parentCtx context.Context, db *sql.DB) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(parentCtx, time.Second*180)
+ defer cancel()
+
+ dbName := testutils.GetDatabaseName(t)
+ base := "limit_large_t1"
+ branch := "limit_large_t2"
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("create database `%s`", dbName))
+ defer func() {
+ execSQLDB(t, ctx, db, "use mo_catalog")
+ execSQLDB(t, ctx, db, fmt.Sprintf("drop database if exists `%s`", dbName))
+ }()
+ execSQLDB(t, ctx, db, fmt.Sprintf("use `%s`", dbName))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (a int primary key, b int, c time)", base))
+ execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s select *, *, '12:34:56' from generate_series(1, 8192*100)g", base))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("data branch create table %s from %s", branch, base))
+
+ execSQLDB(t, ctx, db, fmt.Sprintf("update %s set b = b + 1 where a between 1 and 10000", base))
+ execSQLDB(t, ctx, db, fmt.Sprintf("update %s set b = b + 2 where a between 10000 and 10001", branch))
+ execSQLDB(t, ctx, db, fmt.Sprintf("delete from %s where a between 30000 and 100000", base))
+
+ fullStmt := fmt.Sprintf("data branch diff %s against %s", branch, base)
+ fullRows := fetchDiffRowsAsStrings(t, ctx, db, fullStmt)
+ // the exact row count depends on the update/delete workload above; log it for reference
+ t.Logf("full diff rows: %d", len(fullRows))
+
+ limitQuery := func(cnt int) {
+ limitStmt := fmt.Sprintf("data branch diff %s against %s output limit %d", branch, base, cnt)
+ limitedRows := fetchDiffRowsAsStrings(t, ctx, db, limitStmt)
+
+ require.NotEmpty(t, limitedRows, "limited diff returned no rows")
+ require.LessOrEqual(t, len(limitedRows), cnt, fmt.Sprintf("limited diff returned too many rows: %v", limitedRows))
+
+ fullSet := make(map[string]struct{}, len(fullRows))
+ for _, row := range fullRows {
+ fullSet[strings.Join(row, "||")] = struct{}{}
+ }
+ for _, row := range limitedRows {
+ _, ok := fullSet[strings.Join(row, "||")]
+ require.Truef(t, ok, "limited diff row not contained in full diff: %v", row)
+ }
+ }
+
+ limitQuery(len(fullRows) * 1 / 100)
+ limitQuery(len(fullRows) * 20 / 100)
+}
+
func runDiffOutputToStage(t *testing.T, parentCtx context.Context, db *sql.DB) {
t.Helper()
@@ -678,6 +830,40 @@ func parseSQLStatements(content string) []string {
return stmts
}
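+// fetchDiffRowsAsStrings executes a diff statement and scans every column as a string
+// (NULL columns become the literal "NULL") so rows can be compared as plain string slices;
+// it fails the test if the statement returns no rows.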
+func fetchDiffRowsAsStrings(t *testing.T, ctx context.Context, db *sql.DB, stmt string) [][]string {
+ t.Helper()
+
+ rows, err := db.QueryContext(ctx, stmt)
+ require.NoErrorf(t, err, "sql: %s", stmt)
+ defer rows.Close()
+
+ cols, err := rows.Columns()
+ require.NoError(t, err)
+
+ result := make([][]string, 0, 8)
+ for rows.Next() {
+ raw := make([]sql.RawBytes, len(cols))
+ dest := make([]any, len(cols))
+ for i := range raw {
+ dest[i] = &raw[i]
+ }
+ require.NoError(t, rows.Scan(dest...))
+
+ row := make([]string, len(cols))
+ for i, b := range raw {
+ if b == nil {
+ row[i] = "NULL"
+ continue
+ }
+ row[i] = string(b)
+ }
+ result = append(result, row)
+ }
+ require.NoErrorf(t, rows.Err(), "sql: %s", stmt)
+ require.NotEmpty(t, result, "diff statement returned no rows: %s", stmt)
+ return result
+}
+
func execDiffAndFetchFile(t *testing.T, ctx context.Context, db *sql.DB, stmt string) string {
t.Helper()