diff --git a/pkg/fileservice/aws_sdk_v2.go b/pkg/fileservice/aws_sdk_v2.go index 413a794c0f47a..189f337e95dfc 100644 --- a/pkg/fileservice/aws_sdk_v2.go +++ b/pkg/fileservice/aws_sdk_v2.go @@ -23,7 +23,9 @@ import ( "iter" "math" gotrace "runtime/trace" + "sort" "strings" + "sync" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -191,6 +193,7 @@ func NewAwsSDKv2( } var _ ObjectStorage = new(AwsSDKv2) +var _ ParallelMultipartWriter = new(AwsSDKv2) func (a *AwsSDKv2) List( ctx context.Context, @@ -444,6 +447,267 @@ func (a *AwsSDKv2) Write( return } +func (a *AwsSDKv2) SupportsParallelMultipart() bool { + return true +} + +func (a *AwsSDKv2) WriteMultipartParallel( + ctx context.Context, + key string, + r io.Reader, + sizeHint *int64, + opt *ParallelMultipartOption, +) (err error) { + defer wrapSizeMismatchErr(&err) + + options := normalizeParallelOption(opt) + if sizeHint != nil && *sizeHint < minMultipartPartSize { + return a.Write(ctx, key, r, sizeHint, options.Expire) + } + if sizeHint != nil { + expectedParts := (*sizeHint + options.PartSize - 1) / options.PartSize + if expectedParts > maxMultipartParts { + return moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", expectedParts) + } + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + bufPool := sync.Pool{ + New: func() any { + buf := make([]byte, options.PartSize) + return &buf + }, + } + + readChunk := func() (bufPtr *[]byte, buf []byte, n int, err error) { + bufPtr = bufPool.Get().(*[]byte) + raw := *bufPtr + n, err = io.ReadFull(r, raw) + switch { + case errors.Is(err, io.EOF): + bufPool.Put(bufPtr) + return nil, nil, 0, io.EOF + case errors.Is(err, io.ErrUnexpectedEOF): + err = io.EOF + return bufPtr, raw, n, err + case err != nil: + bufPool.Put(bufPtr) + return nil, nil, 0, err + default: + return bufPtr, raw, n, nil + } + } + + firstBufPtr, firstBuf, firstN, err := readChunk() + if err != nil && !errors.Is(err, io.EOF) { + return err + } + if firstN == 0 && errors.Is(err, io.EOF) { + return nil + } + if errors.Is(err, io.EOF) && int64(firstN) < minMultipartPartSize { + data := make([]byte, firstN) + copy(data, firstBuf[:firstN]) + bufPool.Put(firstBufPtr) + size := int64(firstN) + return a.Write(ctx, key, bytes.NewReader(data), &size, options.Expire) + } + + output, createErr := DoWithRetry("create multipart upload", func() (*s3.CreateMultipartUploadOutput, error) { + return a.client.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{ + Bucket: ptrTo(a.bucket), + Key: ptrTo(key), + Expires: options.Expire, + }) + }, maxRetryAttemps, IsRetryableError) + if createErr != nil { + bufPool.Put(firstBufPtr) + return createErr + } + + defer func() { + if err != nil { + _, abortErr := a.client.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{ + Bucket: ptrTo(a.bucket), + Key: ptrTo(key), + UploadId: output.UploadId, + }) + err = errors.Join(err, abortErr) + } + }() + + type partJob struct { + num int32 + buf []byte + bufPtr *[]byte + n int + } + + var ( + partNum int32 + parts []types.CompletedPart + partsLock sync.Mutex + wg sync.WaitGroup + errOnce sync.Once + firstErr error + ) + + setErr := func(e error) { + if e == nil { + return + } + errOnce.Do(func() { + firstErr = e + cancel() + }) + } + + jobCh := make(chan partJob, options.Concurrency*2) + + startWorker := func() error { + wg.Add(1) + return getParallelUploadPool().Submit(func() { + defer wg.Done() + for job := range jobCh { + if ctx.Err() != nil { + if job.bufPtr != nil { + bufPool.Put(job.bufPtr) + } + continue + } + uploadOutput, 
uploadErr := DoWithRetry("upload part", func() (*s3.UploadPartOutput, error) { + return a.client.UploadPart(ctx, &s3.UploadPartInput{ + Bucket: ptrTo(a.bucket), + Key: ptrTo(key), + PartNumber: &job.num, + UploadId: output.UploadId, + Body: bytes.NewReader(job.buf[:job.n]), + }) + }, maxRetryAttemps, IsRetryableError) + if uploadErr != nil { + setErr(uploadErr) + if job.bufPtr != nil { + bufPool.Put(job.bufPtr) + } + continue + } + if job.bufPtr != nil { + bufPool.Put(job.bufPtr) + } + partsLock.Lock() + parts = append(parts, types.CompletedPart{ + ETag: uploadOutput.ETag, + PartNumber: ptrTo(job.num), + }) + partsLock.Unlock() + } + }) + } + + for i := 0; i < options.Concurrency; i++ { + if submitErr := startWorker(); submitErr != nil { + setErr(submitErr) + break + } + } + + sendJob := func(bufPtr *[]byte, buf []byte, n int) bool { + partNum++ + if partNum > maxMultipartParts { + setErr(moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", partNum)) + if bufPtr != nil { + bufPool.Put(bufPtr) + } + return false + } + job := partJob{ + num: partNum, + buf: buf, + bufPtr: bufPtr, + n: n, + } + select { + case jobCh <- job: + return true + case <-ctx.Done(): + if bufPtr != nil { + bufPool.Put(bufPtr) + } + setErr(ctx.Err()) + return false + } + } + + if !sendJob(firstBufPtr, firstBuf, firstN) { + close(jobCh) + wg.Wait() + if firstErr != nil { + return firstErr + } + return ctx.Err() + } + + for { + nextBufPtr, nextBuf, nextN, readErr := readChunk() + if errors.Is(readErr, io.EOF) && nextN == 0 { + break + } + if readErr != nil && !errors.Is(readErr, io.EOF) { + setErr(readErr) + if nextBufPtr != nil { + bufPool.Put(nextBufPtr) + } + break + } + if nextN == 0 { + if nextBufPtr != nil { + bufPool.Put(nextBufPtr) + } + break + } + if !sendJob(nextBufPtr, nextBuf, nextN) { + break + } + if readErr != nil && errors.Is(readErr, io.EOF) { + break + } + } + + close(jobCh) + wg.Wait() + + if firstErr != nil { + err = firstErr + return err + } + if len(parts) == 0 { + return nil + } + if len(parts) != int(partNum) { + return moerr.NewInternalErrorNoCtxf("multipart upload incomplete, expect %d parts got %d", partNum, len(parts)) + } + + sort.Slice(parts, func(i, j int) bool { + return *parts[i].PartNumber < *parts[j].PartNumber + }) + + _, err = DoWithRetry("complete multipart upload", func() (*s3.CompleteMultipartUploadOutput, error) { + return a.client.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ + Bucket: ptrTo(a.bucket), + Key: ptrTo(key), + UploadId: output.UploadId, + MultipartUpload: &types.CompletedMultipartUpload{Parts: parts}, + }) + }, maxRetryAttemps, IsRetryableError) + if err != nil { + return err + } + + return nil +} + func (a *AwsSDKv2) Read( ctx context.Context, key string, diff --git a/pkg/fileservice/get.go b/pkg/fileservice/get.go index 1e5e6740c9dba..6a9d9ac384236 100644 --- a/pkg/fileservice/get.go +++ b/pkg/fileservice/get.go @@ -53,6 +53,16 @@ func Get[T any](fs FileService, name string) (res T, err error) { var NoDefaultCredentialsForETL = os.Getenv("MO_NO_DEFAULT_CREDENTIALS") != "" +func etlParallelMode(ctx context.Context) ParallelMode { + if mode, ok := parallelModeFromContext(ctx); ok { + return mode + } + if mode, ok := parseParallelMode(strings.TrimSpace(os.Getenv("MO_ETL_PARALLEL_MODE"))); ok { + return mode + } + return ParallelOff +} + // GetForETL get or creates a FileService instance for ETL operations // if service part of path is empty, a LocalETLFS will be created // if service part of path is not empty, a ETLFileService 
typed instance will be extracted from fs argument @@ -110,6 +120,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer KeySecret: accessSecret, KeyPrefix: keyPrefix, Name: name, + ParallelMode: etlParallelMode(ctx), }, DisabledCacheConfig, nil, @@ -143,6 +154,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer Bucket: bucket, KeyPrefix: keyPrefix, Name: name, + ParallelMode: etlParallelMode(ctx), }, DisabledCacheConfig, nil, @@ -157,6 +169,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer } args.NoBucketValidation = true args.IsHDFS = fsPath.Service == "hdfs" + args.ParallelMode = etlParallelMode(ctx) res, err = NewS3FS( ctx, args, @@ -198,6 +211,7 @@ func GetForETL(ctx context.Context, fs FileService, path string) (res ETLFileSer KeyPrefix: keyPrefix, Name: name, IsMinio: true, + ParallelMode: etlParallelMode(ctx), }, DisabledCacheConfig, nil, diff --git a/pkg/fileservice/get_test.go b/pkg/fileservice/get_test.go index 17a69b9df3d8e..125dc31be0bc1 100644 --- a/pkg/fileservice/get_test.go +++ b/pkg/fileservice/get_test.go @@ -19,6 +19,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "iter" ) func TestGetForBackup(t *testing.T) { @@ -30,3 +31,56 @@ func TestGetForBackup(t *testing.T) { assert.True(t, ok) assert.Equal(t, dir, localFS.rootPath) } + +func TestGetForBackupS3Opts(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + spec := JoinPath("s3-opts,endpoint=disk,bucket="+dir+",prefix=backup-prefix,name=backup", "object") + fs, err := GetForBackup(ctx, spec) + assert.Nil(t, err) + s3fs, ok := fs.(*S3FS) + assert.True(t, ok) + assert.Equal(t, "backup", s3fs.name) + assert.Equal(t, "backup-prefix", s3fs.keyPrefix) +} + +type dummyFileService struct{ name string } + +func (d dummyFileService) Delete(ctx context.Context, filePaths ...string) error { return nil } +func (d dummyFileService) Name() string { return d.name } +func (d dummyFileService) Read(ctx context.Context, vector *IOVector) error { return nil } +func (d dummyFileService) ReadCache(ctx context.Context, vector *IOVector) error { return nil } +func (d dummyFileService) Write(ctx context.Context, vector IOVector) error { return nil } +func (d dummyFileService) List(ctx context.Context, dirPath string) iter.Seq2[*DirEntry, error] { + return func(yield func(*DirEntry, error) bool) { + yield(&DirEntry{Name: "a"}, nil) + } +} +func (d dummyFileService) StatFile(ctx context.Context, filePath string) (*DirEntry, error) { + return &DirEntry{Name: filePath}, nil +} +func (d dummyFileService) PrefetchFile(ctx context.Context, filePath string) error { return nil } +func (d dummyFileService) Cost() *CostAttr { return nil } +func (d dummyFileService) Close(ctx context.Context) {} + +func TestGetFromMappings(t *testing.T) { + fs1 := dummyFileService{name: "first"} + fs2 := dummyFileService{name: "second"} + mapping, err := NewFileServices("first", fs1, fs2) + assert.NoError(t, err) + + var res FileService + res, err = Get[FileService](mapping, "second") + assert.NoError(t, err) + assert.Equal(t, "second", res.Name()) + + _, err = Get[FileService](mapping, "missing") + assert.Error(t, err) + + res, err = Get[FileService](fs1, "first") + assert.NoError(t, err) + assert.Equal(t, "first", res.Name()) + + _, err = Get[FileService](fs1, "other") + assert.Error(t, err) +} diff --git a/pkg/fileservice/object_storage.go b/pkg/fileservice/object_storage.go index 426da383721c9..2e61fb48ab15c 100644 --- 
a/pkg/fileservice/object_storage.go +++ b/pkg/fileservice/object_storage.go @@ -18,10 +18,63 @@ import ( "context" "io" "iter" + "runtime" + "sync" "time" + + "github.com/panjf2000/ants/v2" ) const smallObjectThreshold = 64 * (1 << 20) +const ( + // defaultParallelMultipartPartSize defines the default per-part size for parallel multipart uploads. + defaultParallelMultipartPartSize = 64 * (1 << 20) + // minMultipartPartSize is the minimum allowed part size for S3-compatible multipart uploads. + minMultipartPartSize = 5 * (1 << 20) + // maxMultipartPartSize is the maximum allowed part size for S3-compatible multipart uploads. + maxMultipartPartSize = 5 * (1 << 30) + // maxMultipartParts is the maximum allowed parts for S3-compatible multipart uploads. + maxMultipartParts = 10000 +) + +var ( + parallelUploadPoolOnce sync.Once + parallelUploadPool *ants.Pool +) + +func getParallelUploadPool() *ants.Pool { + parallelUploadPoolOnce.Do(func() { + pool, err := ants.NewPool(runtime.NumCPU()) + if err != nil { + panic(err) + } + parallelUploadPool = pool + }) + return parallelUploadPool +} + +func normalizeParallelOption(opt *ParallelMultipartOption) ParallelMultipartOption { + res := ParallelMultipartOption{} + if opt != nil { + res = *opt + } + if res.PartSize <= 0 { + res.PartSize = defaultParallelMultipartPartSize + } + if res.PartSize < minMultipartPartSize { + res.PartSize = minMultipartPartSize + } + if res.PartSize > maxMultipartPartSize { + res.PartSize = maxMultipartPartSize + } + if res.Concurrency <= 0 { + res.Concurrency = runtime.NumCPU() + } + if res.Concurrency < 1 { + res.Concurrency = 1 + } + return res +} type ObjectStorage interface { // List lists objects with specified prefix @@ -78,3 +131,25 @@ type ObjectStorage interface { err error, ) } + +// ParallelMultipartWriter is implemented by storages that support parallel multipart uploads. +type ParallelMultipartWriter interface { + SupportsParallelMultipart() bool + WriteMultipartParallel( + ctx context.Context, + key string, + r io.Reader, + sizeHint *int64, + opt *ParallelMultipartOption, + ) error +} + +// ParallelMultipartOption controls part size and parallelism of multipart uploads. +type ParallelMultipartOption struct { + // PartSize configures each part size; defaults to 64MB if zero. + PartSize int64 + // Concurrency configures worker count; defaults to runtime.NumCPU() if zero. + Concurrency int + // Expire sets object expiration. 
+ Expire *time.Time +} diff --git a/pkg/fileservice/object_storage_arguments.go b/pkg/fileservice/object_storage_arguments.go index 5d178a65796f8..f6fde51332aa0 100644 --- a/pkg/fileservice/object_storage_arguments.go +++ b/pkg/fileservice/object_storage_arguments.go @@ -28,13 +28,14 @@ import ( type ObjectStorageArguments struct { // misc - Name string `toml:"name"` - KeyPrefix string `toml:"key-prefix"` - SharedConfigProfile string `toml:"shared-config-profile"` - NoDefaultCredentials bool `toml:"no-default-credentials"` - NoBucketValidation bool `toml:"no-bucket-validation"` - Concurrency int64 `toml:"concurrency"` - MaxConnsPerHost int `toml:"max-conns-per-host"` + Name string `toml:"name"` + KeyPrefix string `toml:"key-prefix"` + SharedConfigProfile string `toml:"shared-config-profile"` + NoDefaultCredentials bool `toml:"no-default-credentials"` + NoBucketValidation bool `toml:"no-bucket-validation"` + Concurrency int64 `toml:"concurrency"` + MaxConnsPerHost int `toml:"max-conns-per-host"` + ParallelMode ParallelMode `toml:"parallel-mode"` // s3 Bucket string `toml:"bucket"` @@ -107,6 +108,10 @@ func (o *ObjectStorageArguments) SetFromString(arguments []string) error { if err == nil { o.MaxConnsPerHost = n } + case "parallel-mode", "parallel": + if mode, ok := parseParallelMode(value); ok { + o.ParallelMode = mode + } case "bucket": o.Bucket = value diff --git a/pkg/fileservice/object_storage_arguments_test.go b/pkg/fileservice/object_storage_arguments_test.go index c8e6eb525e903..d87eee3475065 100644 --- a/pkg/fileservice/object_storage_arguments_test.go +++ b/pkg/fileservice/object_storage_arguments_test.go @@ -185,6 +185,61 @@ func TestAWSRegion(t *testing.T) { assert.NotNil(t, args.validate()) } +func TestSetFromStringParallelMode(t *testing.T) { + var args ObjectStorageArguments + assert.NoError(t, args.SetFromString([]string{"parallel-mode=force"})) + assert.Equal(t, ParallelForce, args.ParallelMode) + + args = ObjectStorageArguments{ + ParallelMode: ParallelAuto, + } + assert.NoError(t, args.SetFromString([]string{"parallel-mode=unknown"})) + assert.Equal(t, ParallelAuto, args.ParallelMode) +} + +func TestObjectStorageArgumentsValidateDefaults(t *testing.T) { + args := ObjectStorageArguments{ + Endpoint: "example.com", + } + assert.NoError(t, args.validate()) + assert.Equal(t, "https://example.com", args.Endpoint) + assert.Equal(t, "mo-service", args.RoleSessionName) +} + +func TestObjectStorageArgumentsShouldLoadDefaultCredentials(t *testing.T) { + t.Setenv("AWS_ACCESS_KEY_ID", "ak") + t.Setenv("AWS_SECRET_ACCESS_KEY", "sk") + args := ObjectStorageArguments{} + assert.True(t, args.shouldLoadDefaultCredentials()) + + args = ObjectStorageArguments{ + NoDefaultCredentials: true, + KeyID: "id", + KeySecret: "secret", + } + assert.False(t, args.shouldLoadDefaultCredentials()) + + args = ObjectStorageArguments{ + NoDefaultCredentials: true, + RoleARN: "arn", + } + assert.True(t, args.shouldLoadDefaultCredentials()) +} + +func TestObjectStorageArgumentsString(t *testing.T) { + args := ObjectStorageArguments{ + Name: "foo", + KeyPrefix: "bar", + Concurrency: 3, + } + s := args.String() + var decoded ObjectStorageArguments + assert.NoError(t, json.Unmarshal([]byte(s), &decoded)) + assert.Equal(t, args.Name, decoded.Name) + assert.Equal(t, args.KeyPrefix, decoded.KeyPrefix) + assert.Equal(t, args.Concurrency, decoded.Concurrency) +} + func TestParseHDFSArgs(t *testing.T) { var args ObjectStorageArguments if err := args.SetFromString([]string{ diff --git 
a/pkg/fileservice/object_storage_http_trace.go b/pkg/fileservice/object_storage_http_trace.go index 7a7602a1618ad..7c9350957d9b7 100644 --- a/pkg/fileservice/object_storage_http_trace.go +++ b/pkg/fileservice/object_storage_http_trace.go @@ -21,6 +21,7 @@ import ( "net/http/httptrace" "time" + "github.com/matrixorigin/matrixone/pkg/common/moerr" "github.com/matrixorigin/matrixone/pkg/common/reuse" ) @@ -35,6 +36,7 @@ func newObjectStorageHTTPTrace(upstream ObjectStorage) *objectStorageHTTPTrace { } var _ ObjectStorage = new(objectStorageHTTPTrace) +var _ ParallelMultipartWriter = new(objectStorageHTTPTrace) func (o *objectStorageHTTPTrace) Delete(ctx context.Context, keys ...string) (err error) { traceInfo := o.newTraceInfo() @@ -78,6 +80,23 @@ func (o *objectStorageHTTPTrace) Write(ctx context.Context, key string, r io.Rea return o.upstream.Write(ctx, key, r, sizeHint, expire) } +func (o *objectStorageHTTPTrace) SupportsParallelMultipart() bool { + if p, ok := o.upstream.(ParallelMultipartWriter); ok { + return p.SupportsParallelMultipart() + } + return false +} + +func (o *objectStorageHTTPTrace) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) (err error) { + traceInfo := o.newTraceInfo() + defer o.closeTraceInfo(traceInfo) + ctx = httptrace.WithClientTrace(ctx, traceInfo.trace) + if p, ok := o.upstream.(ParallelMultipartWriter); ok { + return p.WriteMultipartParallel(ctx, key, r, sizeHint, opt) + } + return moerr.NewNotSupportedNoCtx("parallel multipart upload") +} + func (o *objectStorageHTTPTrace) newTraceInfo() *traceInfo { return reuse.Alloc[traceInfo](nil) } diff --git a/pkg/fileservice/object_storage_http_trace_test.go b/pkg/fileservice/object_storage_http_trace_test.go new file mode 100644 index 0000000000000..789bb6f83a311 --- /dev/null +++ b/pkg/fileservice/object_storage_http_trace_test.go @@ -0,0 +1,78 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fileservice + +import ( + "context" + "net/http/httptrace" + "strings" + "testing" + + "github.com/matrixorigin/matrixone/pkg/common/moerr" + "github.com/stretchr/testify/require" +) + +func TestObjectStorageHTTPTraceWriteMultipartParallel(t *testing.T) { + upstream := &mockParallelObjectStorage{supports: true} + wrapped := newObjectStorageHTTPTrace(upstream) + + err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil) + require.NoError(t, err) + require.NotNil(t, upstream.ctx) + require.NotNil(t, httptrace.ContextClientTrace(upstream.ctx)) + require.Equal(t, "key", upstream.key) +} + +func TestObjectStorageHTTPTraceWriteMultipartParallelUnsupported(t *testing.T) { + wrapped := newObjectStorageHTTPTrace(dummyObjectStorage{}) + + err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil) + require.Error(t, err) + require.True(t, moerr.IsMoErrCode(err, moerr.ErrNotSupported)) +} + +func TestObjectStorageHTTPTraceDelegates(t *testing.T) { + upstream := &recordingObjectStorage{} + wrapped := newObjectStorageHTTPTrace(upstream) + ctx := context.Background() + + require.NoError(t, wrapped.Delete(ctx, "a")) + exists, err := wrapped.Exists(ctx, "b") + require.NoError(t, err) + require.True(t, exists) + iterSeq := wrapped.List(ctx, "c") + var listed []string + iterSeq(func(entry *DirEntry, err error) bool { + require.NoError(t, err) + listed = append(listed, entry.Name) + return true + }) + reader, err := wrapped.Read(ctx, "d", nil, nil) + require.NoError(t, err) + defer reader.Close() + buf := make([]byte, 4) + _, _ = reader.Read(buf) + size, err := wrapped.Stat(ctx, "e") + require.NoError(t, err) + require.Equal(t, int64(3), size) + require.NoError(t, wrapped.Write(ctx, "f", strings.NewReader("payload"), nil, nil)) + + require.Len(t, upstream.calls, 6) + for _, ctx := range upstream.ctxs { + require.NotNil(t, httptrace.ContextClientTrace(ctx)) + } + require.ElementsMatch(t, []string{"delete", "exists", "list", "read", "stat", "write"}, upstream.calls) + require.Equal(t, []string{"one"}, listed) +} diff --git a/pkg/fileservice/object_storage_metrics.go b/pkg/fileservice/object_storage_metrics.go index 926b5d885eeaa..b0bce4c67207a 100644 --- a/pkg/fileservice/object_storage_metrics.go +++ b/pkg/fileservice/object_storage_metrics.go @@ -20,6 +20,7 @@ import ( "iter" "time" + "github.com/matrixorigin/matrixone/pkg/common/moerr" metric "github.com/matrixorigin/matrixone/pkg/util/metric/v2" "github.com/prometheus/client_golang/prometheus" ) @@ -55,6 +56,7 @@ func newObjectStorageMetrics( } var _ ObjectStorage = new(objectStorageMetrics) +var _ ParallelMultipartWriter = new(objectStorageMetrics) func (o *objectStorageMetrics) Delete(ctx context.Context, keys ...string) (err error) { o.numDelete.Inc() @@ -99,3 +101,18 @@ func (o *objectStorageMetrics) Write(ctx context.Context, key string, r io.Reade o.numWrite.Inc() return o.upstream.Write(ctx, key, r, sizeHint, expire) } + +func (o *objectStorageMetrics) SupportsParallelMultipart() bool { + if p, ok := o.upstream.(ParallelMultipartWriter); ok { + return p.SupportsParallelMultipart() + } + return false +} + +func (o *objectStorageMetrics) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error { + o.numWrite.Inc() + if p, ok := o.upstream.(ParallelMultipartWriter); ok { + return p.WriteMultipartParallel(ctx, key, r, sizeHint, opt) + } + return moerr.NewNotSupportedNoCtx("parallel 
multipart upload") +} diff --git a/pkg/fileservice/object_storage_metrics_test.go b/pkg/fileservice/object_storage_metrics_test.go new file mode 100644 index 0000000000000..5a7c99592c7ac --- /dev/null +++ b/pkg/fileservice/object_storage_metrics_test.go @@ -0,0 +1,93 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileservice + +import ( + "context" + "strings" + "testing" + + "github.com/matrixorigin/matrixone/pkg/common/moerr" + metric "github.com/matrixorigin/matrixone/pkg/util/metric/v2" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" +) + +func TestObjectStorageMetricsWriteMultipartParallel(t *testing.T) { + name := t.Name() + upstream := &mockParallelObjectStorage{supports: true} + wrapped := newObjectStorageMetrics(upstream, name) + gauge := metric.FSObjectStorageOperations.WithLabelValues(name, "write") + before := testutil.ToFloat64(gauge) + + err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil) + require.NoError(t, err) + require.Equal(t, before+1, testutil.ToFloat64(gauge)) + require.Equal(t, "key", upstream.key) +} + +func TestObjectStorageMetricsWriteMultipartParallelNotSupported(t *testing.T) { + name := t.Name() + wrapped := newObjectStorageMetrics(dummyObjectStorage{}, name) + gauge := metric.FSObjectStorageOperations.WithLabelValues(name, "write") + before := testutil.ToFloat64(gauge) + + err := wrapped.WriteMultipartParallel(context.Background(), "key", strings.NewReader("data"), nil, nil) + require.Error(t, err) + require.True(t, moerr.IsMoErrCode(err, moerr.ErrNotSupported)) + require.Equal(t, before+1, testutil.ToFloat64(gauge)) +} + +func TestObjectStorageMetricsDelegates(t *testing.T) { + name := t.Name() + upstream := &recordingObjectStorage{} + wrapped := newObjectStorageMetrics(upstream, name) + + require.NoError(t, wrapped.Delete(context.Background(), "a")) + _, _ = wrapped.Exists(context.Background(), "b") + seq := wrapped.List(context.Background(), "c") + seq(func(_ *DirEntry, _ error) bool { return true }) + rc, err := wrapped.Read(context.Background(), "d", nil, nil) + require.NoError(t, err) + require.NoError(t, rc.Close()) + _, _ = wrapped.Stat(context.Background(), "e") + require.NoError(t, wrapped.Write(context.Background(), "f", strings.NewReader("x"), nil, nil)) + + require.ElementsMatch(t, []string{ + "delete", "exists", "list", "read", "stat", "write", + }, upstream.calls) + + // gauges incremented + require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "delete")) >= 1) + require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "exists")) >= 1) + require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "list")) >= 1) + require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "read")) >= 1) + require.True(t, 
testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "stat")) >= 1) + require.True(t, testutil.ToFloat64(metric.FSObjectStorageOperations.WithLabelValues(name, "write")) >= 1) +} + +func TestObjectStorageMetricsReadCloseDecrementsActive(t *testing.T) { + name := t.Name() + upstream := &recordingObjectStorage{} + wrapped := newObjectStorageMetrics(upstream, name) + active := metric.FSObjectStorageOperations.WithLabelValues(name, "active-read") + before := testutil.ToFloat64(active) + + r, err := wrapped.Read(context.Background(), "key", nil, nil) + require.NoError(t, err) + require.Equal(t, before+1, testutil.ToFloat64(active)) + require.NoError(t, r.Close()) + require.Equal(t, before, testutil.ToFloat64(active)) +} diff --git a/pkg/fileservice/object_storage_semaphore.go b/pkg/fileservice/object_storage_semaphore.go index 243e1eaf7c805..a6d4768af15d0 100644 --- a/pkg/fileservice/object_storage_semaphore.go +++ b/pkg/fileservice/object_storage_semaphore.go @@ -16,6 +16,7 @@ package fileservice import ( "context" + "github.com/matrixorigin/matrixone/pkg/common/moerr" "io" "iter" "sync" @@ -46,6 +47,7 @@ func (o *objectStorageSemaphore) release() { } var _ ObjectStorage = new(objectStorageSemaphore) +var _ ParallelMultipartWriter = new(objectStorageSemaphore) func (o *objectStorageSemaphore) Delete(ctx context.Context, keys ...string) (err error) { o.acquire() @@ -107,3 +109,19 @@ func (o *objectStorageSemaphore) Write(ctx context.Context, key string, r io.Rea defer o.release() return o.upstream.Write(ctx, key, r, sizeHint, expire) } + +func (o *objectStorageSemaphore) SupportsParallelMultipart() bool { + if p, ok := o.upstream.(ParallelMultipartWriter); ok { + return p.SupportsParallelMultipart() + } + return false +} + +func (o *objectStorageSemaphore) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) (err error) { + o.acquire() + defer o.release() + if p, ok := o.upstream.(ParallelMultipartWriter); ok { + return p.WriteMultipartParallel(ctx, key, r, sizeHint, opt) + } + return moerr.NewNotSupportedNoCtx("parallel multipart upload") +} diff --git a/pkg/fileservice/object_storage_semaphore_test.go b/pkg/fileservice/object_storage_semaphore_test.go new file mode 100644 index 0000000000000..c9adcae8919d9 --- /dev/null +++ b/pkg/fileservice/object_storage_semaphore_test.go @@ -0,0 +1,88 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fileservice + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestObjectStorageSemaphoreSerializes(t *testing.T) { + start := make(chan struct{}, 2) + wait := make(chan struct{}) + upstream := &blockingObjectStorage{ + start: start, + wait: wait, + } + sem := newObjectStorageSemaphore(upstream, 1) + + done := make(chan struct{}) + go func() { + require.NoError(t, sem.Write(context.Background(), "a", nil, nil, nil)) + close(done) + }() + + select { + case <-start: + case <-time.After(time.Second): + t.Fatal("first write did not start") + } + + startSecond := make(chan struct{}) + go func() { + defer close(startSecond) + require.NoError(t, sem.Write(context.Background(), "b", nil, nil, nil)) + }() + + select { + case <-startSecond: + t.Fatal("second write started before release") + case <-time.After(50 * time.Millisecond): + } + + close(wait) // release first + select { + case <-startSecond: + case <-time.After(time.Second): + t.Fatal("second write not started after release") + } + <-done +} + +func TestObjectStorageSemaphoreReleasesOnError(t *testing.T) { + start := make(chan struct{}, 1) + wait := make(chan struct{}) + upstream := &blockingObjectStorage{ + start: start, + wait: wait, + err: context.DeadlineExceeded, + } + sem := newObjectStorageSemaphore(upstream, 1) + + // release the blocked write once it has started + go func() { + <-start + close(wait) + }() + + err := sem.Write(context.Background(), "a", nil, nil, nil) + require.Error(t, err) + + // another call should proceed after the failed one + require.NoError(t, sem.Delete(context.Background(), "x")) +} diff --git a/pkg/fileservice/object_storage_test_helper_test.go b/pkg/fileservice/object_storage_test_helper_test.go new file mode 100644 index 0000000000000..3ddfb8687cda3 --- /dev/null +++ b/pkg/fileservice/object_storage_test_helper_test.go @@ -0,0 +1,130 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fileservice + +import ( + "context" + "io" + "iter" + "strings" + "time" +) + +type dummyObjectStorage struct{} + +func (dummyObjectStorage) Delete(ctx context.Context, keys ...string) error { + return nil +} + +func (dummyObjectStorage) Exists(ctx context.Context, key string) (bool, error) { + return false, nil +} + +func (dummyObjectStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] { + return func(yield func(*DirEntry, error) bool) {} +} + +func (dummyObjectStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("")), nil +} + +func (dummyObjectStorage) Stat(ctx context.Context, key string) (int64, error) { + return 0, nil +} + +func (dummyObjectStorage) Write(ctx context.Context, key string, r io.Reader, sizeHint *int64, expire *time.Time) error { + _, _ = io.Copy(io.Discard, r) + return nil +} + +type mockParallelObjectStorage struct { + dummyObjectStorage + ctx context.Context + key string + sizeHint *int64 + opt *ParallelMultipartOption + err error + supports bool +} + +func (m *mockParallelObjectStorage) SupportsParallelMultipart() bool { + return m.supports +} + +func (m *mockParallelObjectStorage) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error { + m.ctx = ctx + m.key = key + m.sizeHint = sizeHint + m.opt = opt + return m.err +} + +type recordingObjectStorage struct { + calls []string + ctxs []context.Context +} + +func (r *recordingObjectStorage) record(ctx context.Context, name string) { + r.calls = append(r.calls, name) + r.ctxs = append(r.ctxs, ctx) +} + +func (r *recordingObjectStorage) Delete(ctx context.Context, keys ...string) error { + r.record(ctx, "delete") + return nil +} + +func (r *recordingObjectStorage) Exists(ctx context.Context, key string) (bool, error) { + r.record(ctx, "exists") + return true, nil +} + +func (r *recordingObjectStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] { + r.record(ctx, "list") + return func(yield func(*DirEntry, error) bool) { + yield(&DirEntry{ + Name: "one", + }, nil) + } +} + +func (r *recordingObjectStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) { + r.record(ctx, "read") + return io.NopCloser(strings.NewReader("data")), nil +} + +func (r *recordingObjectStorage) Stat(ctx context.Context, key string) (int64, error) { + r.record(ctx, "stat") + return 3, nil +} + +func (r *recordingObjectStorage) Write(ctx context.Context, key string, rd io.Reader, sizeHint *int64, expire *time.Time) error { + r.record(ctx, "write") + _, _ = io.Copy(io.Discard, rd) + return nil +} + +type blockingObjectStorage struct { + dummyObjectStorage + start chan struct{} + wait chan struct{} + err error +} + +func (b *blockingObjectStorage) Write(ctx context.Context, key string, rd io.Reader, sizeHint *int64, expire *time.Time) error { + b.start <- struct{}{} + <-b.wait + return b.err +} diff --git a/pkg/fileservice/parallel_mode.go b/pkg/fileservice/parallel_mode.go new file mode 100644 index 0000000000000..6201ea0811c5d --- /dev/null +++ b/pkg/fileservice/parallel_mode.go @@ -0,0 +1,62 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileservice + +import ( + "context" + "strings" +) + +// ParallelMode controls when multipart parallel uploads are used. +type ParallelMode uint8 + +const ( + ParallelOff ParallelMode = iota + ParallelAuto + ParallelForce +) + +func parseParallelMode(s string) (ParallelMode, bool) { + switch strings.ToLower(s) { + case "off", "false", "0", "": + return ParallelOff, true + case "auto": + return ParallelAuto, true + case "force", "on", "true", "1": + return ParallelForce, true + default: + return ParallelOff, false + } +} + +type parallelModeKey struct{} + +// WithParallelMode sets a per-call parallel mode override on context. +func WithParallelMode(ctx context.Context, mode ParallelMode) context.Context { + return context.WithValue(ctx, parallelModeKey{}, mode) +} + +// parallelModeFromContext retrieves a parallel mode override if present. +func parallelModeFromContext(ctx context.Context) (ParallelMode, bool) { + if ctx == nil { + return ParallelOff, false + } + if v := ctx.Value(parallelModeKey{}); v != nil { + if mode, ok := v.(ParallelMode); ok { + return mode, true + } + } + return ParallelOff, false +} diff --git a/pkg/fileservice/parallel_mode_test.go b/pkg/fileservice/parallel_mode_test.go new file mode 100644 index 0000000000000..ae9d8b345c9f6 --- /dev/null +++ b/pkg/fileservice/parallel_mode_test.go @@ -0,0 +1,74 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fileservice + +import ( + "context" + "testing" +) + +func TestParseParallelModeVariants(t *testing.T) { + cases := []struct { + in string + expect ParallelMode + ok bool + }{ + {"", ParallelOff, true}, + {"off", ParallelOff, true}, + {"false", ParallelOff, true}, + {"0", ParallelOff, true}, + {"auto", ParallelAuto, true}, + {"force", ParallelForce, true}, + {"on", ParallelForce, true}, + {"1", ParallelForce, true}, + {"xxx", ParallelOff, false}, + } + + for _, c := range cases { + mode, ok := parseParallelMode(c.in) + if mode != c.expect || ok != c.ok { + t.Fatalf("input %s got (%v,%v) expect (%v,%v)", c.in, mode, ok, c.expect, c.ok) + } + } +} + +func TestWithParallelModeOnContext(t *testing.T) { + base := context.Background() + ctx := WithParallelMode(base, ParallelForce) + mode, ok := parallelModeFromContext(ctx) + if !ok || mode != ParallelForce { + t.Fatalf("expected force override, got %v %v", mode, ok) + } + if _, ok := parallelModeFromContext(base); ok { + t.Fatalf("unexpected mode on base context") + } +} + +func TestETLParallelModePriority(t *testing.T) { + t.Setenv("MO_ETL_PARALLEL_MODE", "off") + if m := etlParallelMode(context.Background()); m != ParallelOff { + t.Fatalf("expect off from env, got %v", m) + } + + ctx := WithParallelMode(context.Background(), ParallelForce) + if m := etlParallelMode(ctx); m != ParallelForce { + t.Fatalf("ctx override should win, got %v", m) + } + + t.Setenv("MO_ETL_PARALLEL_MODE", "auto") + if m := etlParallelMode(context.Background()); m != ParallelAuto { + t.Fatalf("env auto not applied, got %v", m) + } +} diff --git a/pkg/fileservice/parallel_sdk_test.go b/pkg/fileservice/parallel_sdk_test.go new file mode 100644 index 0000000000000..54edf6511c451 --- /dev/null +++ b/pkg/fileservice/parallel_sdk_test.go @@ -0,0 +1,552 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package fileservice
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"strconv"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	costypes "github.com/tencentyun/cos-go-sdk-v5"
+)
+
+func newMockAWSServer(t *testing.T, failPart int32) (*httptest.Server, *awsServerState) {
+	t.Helper()
+	state := &awsServerState{
+		failPart: failPart,
+		parts:    make(map[int32][]byte),
+	}
+	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		switch {
+		case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploads"):
+			if state.failCreate {
+				w.WriteHeader(http.StatusInternalServerError)
+				return
+			}
+			_, _ = io.ReadAll(r.Body)
+			w.Header().Set("Content-Type", "application/xml")
+			_, _ = fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, state.uploadID)
+		case r.Method == http.MethodPut && !strings.Contains(r.URL.RawQuery, "partNumber") && !strings.Contains(r.URL.RawQuery, "uploadId") && !strings.Contains(r.URL.RawQuery, "uploads"):
+			body, _ := io.ReadAll(r.Body)
+			state.mu.Lock()
+			state.putCount++
+			state.putBody = append([]byte{}, body...)
+			state.mu.Unlock()
+			w.WriteHeader(http.StatusOK)
+		case r.Method == http.MethodPut && strings.Contains(r.URL.RawQuery, "partNumber"):
+			partStr := r.URL.Query().Get("partNumber")
+			pn, _ := strconv.Atoi(partStr)
+			body, _ := io.ReadAll(r.Body)
+			if state.failPart > 0 && int32(pn) == state.failPart {
+				w.WriteHeader(http.StatusInternalServerError)
+				return
+			}
+			state.mu.Lock()
+			state.parts[int32(pn)] = body
+			state.mu.Unlock()
+			w.Header().Set("ETag", fmt.Sprintf("\"etag-%d\"", pn))
+		case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploadId"):
+			body, _ := io.ReadAll(r.Body)
+			if state.failComplete {
+				w.WriteHeader(http.StatusInternalServerError)
+				return
+			}
+			state.mu.Lock()
+			state.completeBody = append([]byte{}, body...)
+			state.mu.Unlock()
+			w.Header().Set("Content-Type", "application/xml")
+			_, _ = w.Write([]byte(`<CompleteMultipartUploadResult><Location>loc</Location><Bucket>bucket</Bucket><Key>object</Key><ETag>"etag"</ETag></CompleteMultipartUploadResult>`))
+		case r.Method == http.MethodDelete && strings.Contains(r.URL.RawQuery, "uploadId"):
+			state.aborted.Store(true)
+			w.WriteHeader(http.StatusOK)
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	})
+	return httptest.NewServer(handler), state
+}
+
+type awsServerState struct {
+	mu           sync.Mutex
+	parts        map[int32][]byte
+	completeBody []byte
+	failPart     int32
+	failComplete bool
+	failCreate   bool
+	uploadID     string
+	aborted      atomic.Bool
+	putCount     int
+	putBody      []byte
+}
+
+func newTestAWSClient(t *testing.T, srv *httptest.Server) *AwsSDKv2 {
+	t.Helper()
+	cfg := aws.Config{
+		Region:      "us-east-1",
+		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider("id", "key", "")),
+		HTTPClient:  srv.Client(),
+		Retryer: func() aws.Retryer {
+			return aws.NopRetryer{}
+		},
+	}
+	client := s3.NewFromConfig(cfg, func(o *s3.Options) {
+		o.UsePathStyle = true
+		o.BaseEndpoint = aws.String(srv.URL)
+	})
+
+	return &AwsSDKv2{
+		name:   "aws-test",
+		bucket: "bucket",
+		client: client,
+	}
+}
+
+func TestAwsParallelMultipartSuccess(t *testing.T) {
+	server, state := newMockAWSServer(t, 0)
+	defer server.Close()
+	state.uploadID = "uid-success"
+
+	sdk := newTestAWSClient(t, server)
+
+	data := bytes.Repeat([]byte("a"), int(minMultipartPartSize+1))
+	size := int64(len(data))
+	err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+		PartSize:    minMultipartPartSize,
+		Concurrency: 2,
+	})
+	if err != nil {
+		t.Fatalf("write failed: %v, parts=%d, complete=%s", err, len(state.parts), string(state.completeBody))
+	}
+	if len(state.parts) != 2 {
+		t.Fatalf("expected 2 parts, got %d", len(state.parts))
+	}
+	if len(state.completeBody) == 0 {
+		t.Fatalf("complete body not recorded")
+	}
+}
+
+func TestAwsParallelMultipartAbortOnError(t *testing.T) {
+	server, state := newMockAWSServer(t, 0)
+	defer server.Close()
+	state.uploadID = "uid-fail"
+	state.failComplete = true
+
+	sdk := newTestAWSClient(t, server)
+
+	data := bytes.Repeat([]byte("b"), int(minMultipartPartSize*2))
+	size := int64(len(data))
+	err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+		PartSize:    minMultipartPartSize,
+		Concurrency: 2,
+	})
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+	if !state.aborted.Load() {
+		t.Fatalf("expected abort request to be sent")
+	}
+}
+
+func TestAwsMultipartFallbackSmallSize(t *testing.T) {
+	server, state := newMockAWSServer(t, 0)
+	defer server.Close()
+	state.uploadID = "uid-small"
+
+	sdk := newTestAWSClient(t, server)
+	data := bytes.Repeat([]byte("x"), int(minMultipartPartSize-1))
+	size := int64(len(data))
+	if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err != nil {
+		t.Fatalf("write failed: %v", err)
+	}
+	if state.putCount != 1 {
+		t.Fatalf("expected single PUT fallback, got %d", state.putCount)
+	}
+	if string(state.putBody) != string(data) {
+		t.Fatalf("unexpected put body")
+	}
+}
+
+func TestAwsMultipartFallbackUnknownSmall(t *testing.T) {
+	server, state := newMockAWSServer(t, 0)
+	defer server.Close()
+	state.uploadID = "uid-unknown"
+
+	sdk := newTestAWSClient(t, server)
+	data := bytes.Repeat([]byte("y"), int(minMultipartPartSize-1))
+	reader := bytes.NewReader(data)
+	if err := sdk.WriteMultipartParallel(context.Background(), "object", reader, nil, nil); err
!= nil { + t.Fatalf("write failed: %v", err) + } + if state.putCount != 1 { + t.Fatalf("expected fallback PUT for unknown small") + } + if string(state.putBody) != string(data) { + t.Fatalf("unexpected put body") + } +} + +func TestAwsMultipartTooManyParts(t *testing.T) { + server, _ := newMockAWSServer(t, 0) + defer server.Close() + + sdk := newTestAWSClient(t, server) + huge := int64(maxMultipartParts+1) * defaultParallelMultipartPartSize + err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), &huge, &ParallelMultipartOption{ + PartSize: defaultParallelMultipartPartSize, + }) + if err == nil { + t.Fatalf("expected error for too many parts") + } +} + +func TestAwsMultipartUploadPartError(t *testing.T) { + server, _ := newMockAWSServer(t, 1) + defer server.Close() + + sdk := newTestAWSClient(t, server) + data := bytes.Repeat([]byte("p"), int(minMultipartPartSize*2)) + size := int64(len(data)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{ + PartSize: minMultipartPartSize, + Concurrency: 2, + }); err == nil { + t.Fatalf("expected upload part error") + } +} + +func TestAwsParallelMultipartUnknownSize(t *testing.T) { + server, state := newMockAWSServer(t, 0) + defer server.Close() + state.uploadID = "uid-unknown-size" + + sdk := newTestAWSClient(t, server) + data := bytes.Repeat([]byte("z"), int(minMultipartPartSize+1)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), nil, &ParallelMultipartOption{ + PartSize: minMultipartPartSize, + Concurrency: 2, + }); err != nil { + t.Fatalf("write failed: %v", err) + } + if len(state.parts) != 2 { + t.Fatalf("expected multipart upload with unknown size") + } +} + +func TestAwsMultipartEmptyReader(t *testing.T) { + server, state := newMockAWSServer(t, 0) + defer server.Close() + + sdk := newTestAWSClient(t, server) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), nil, nil); err != nil { + t.Fatalf("expected nil error for empty reader, got %v", err) + } + if state.putCount != 0 && len(state.parts) != 0 { + t.Fatalf("no upload should happen for empty reader") + } +} + +func TestAwsMultipartContextCanceled(t *testing.T) { + server, _ := newMockAWSServer(t, 0) + defer server.Close() + + sdk := newTestAWSClient(t, server) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := sdk.WriteMultipartParallel(ctx, "object", bytes.NewReader([]byte("data")), nil, nil); err == nil { + t.Fatalf("expected context canceled error") + } +} + +func TestAwsMultipartCreateFail(t *testing.T) { + server, state := newMockAWSServer(t, 0) + defer server.Close() + state.failCreate = true + + sdk := newTestAWSClient(t, server) + data := bytes.Repeat([]byte("i"), int(minMultipartPartSize+1)) + size := int64(len(data)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err == nil { + t.Fatalf("expected create multipart error") + } +} + +func newMockCOSServer(t *testing.T, failPart int) (*httptest.Server, *cosServerState) { + t.Helper() + state := &cosServerState{ + failPart: failPart, + parts: make(map[int][]byte), + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploads"): + if state.failCreate { + w.WriteHeader(http.StatusInternalServerError) + return + } + _, _ = io.ReadAll(r.Body) + 
w.Header().Set("Content-Type", "application/xml")
+			_, _ = fmt.Fprintf(w, `<InitiateMultipartUploadResult><UploadId>%s</UploadId></InitiateMultipartUploadResult>`, state.uploadID)
+		case r.Method == http.MethodPut && !strings.Contains(r.URL.RawQuery, "partNumber") && !strings.Contains(r.URL.RawQuery, "uploadId") && !strings.Contains(r.URL.RawQuery, "uploads"):
+			body, _ := io.ReadAll(r.Body)
+			state.mu.Lock()
+			state.putCount++
+			state.putBody = append([]byte{}, body...)
+			state.mu.Unlock()
+			w.WriteHeader(http.StatusOK)
+		case r.Method == http.MethodPut && strings.Contains(r.URL.RawQuery, "partNumber"):
+			partStr := r.URL.Query().Get("partNumber")
+			pn, _ := strconv.Atoi(partStr)
+			body, _ := io.ReadAll(r.Body)
+			if state.failPart > 0 && pn == state.failPart {
+				w.WriteHeader(http.StatusInternalServerError)
+				return
+			}
+			state.mu.Lock()
+			state.parts[pn] = body
+			state.mu.Unlock()
+		case r.Method == http.MethodPost && strings.Contains(r.URL.RawQuery, "uploadId"):
+			body, _ := io.ReadAll(r.Body)
+			if state.failComplete {
+				w.WriteHeader(http.StatusInternalServerError)
+				return
+			}
+			state.mu.Lock()
+			state.completeBody = append([]byte{}, body...)
+			state.mu.Unlock()
+			w.Header().Set("Content-Type", "application/xml")
+			state.respBody = `<CompleteMultipartUploadResult><Location>loc</Location><Bucket>bucket</Bucket><Key>object</Key><ETag>etag</ETag></CompleteMultipartUploadResult>`
+			state.completed.Store(true)
+			_, _ = w.Write([]byte(state.respBody))
+		case r.Method == http.MethodDelete && strings.Contains(r.URL.RawQuery, "uploadId"):
+			state.aborted.Store(true)
+			w.WriteHeader(http.StatusOK)
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	})
+	return httptest.NewServer(handler), state
+}
+
+type cosServerState struct {
+	mu           sync.Mutex
+	parts        map[int][]byte
+	completeBody []byte
+	failPart     int
+	failComplete bool
+	failCreate   bool
+	uploadID     string
+	aborted      atomic.Bool
+	completed    atomic.Bool
+	respBody     string
+	putCount     int
+	putBody      []byte
+}
+
+func newTestCOSClient(t *testing.T, srv *httptest.Server) *QCloudSDK {
+	t.Helper()
+	baseURL, err := url.Parse(srv.URL)
+	if err != nil {
+		t.Fatalf("parse url: %v", err)
+	}
+
+	client := costypes.NewClient(
+		&costypes.BaseURL{BucketURL: baseURL},
+		srv.Client(),
+	)
+	client.Conf.EnableCRC = false
+
+	return &QCloudSDK{
+		name:   "cos-test",
+		client: client,
+	}
+}
+
+func TestCOSParallelMultipartSuccess(t *testing.T) {
+	server, state := newMockCOSServer(t, 0)
+	defer server.Close()
+	state.uploadID = "cos-uid"
+
+	sdk := newTestCOSClient(t, server)
+	data := bytes.Repeat([]byte("c"), int(minMultipartPartSize+2))
+	size := int64(len(data))
+	err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+		PartSize:    minMultipartPartSize,
+		Concurrency: 2,
+	})
+	if err != nil {
+		t.Fatalf("write failed: %v, parts=%d, complete=%s", err, len(state.parts), string(state.completeBody))
+	}
+	if len(state.parts) != 2 {
+		t.Fatalf("expected 2 parts, got %d", len(state.parts))
+	}
+	if len(state.completeBody) == 0 {
+		t.Fatalf("complete body not recorded")
+	}
+}
+
+func TestCOSParallelMultipartAbortOnError(t *testing.T) {
+	server, state := newMockCOSServer(t, 0)
+	defer server.Close()
+	state.uploadID = "cos-uid-fail"
+	state.failComplete = true
+
+	sdk := newTestCOSClient(t, server)
+	data := bytes.Repeat([]byte("d"), int(minMultipartPartSize*2))
+	size := int64(len(data))
+	err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{
+		PartSize:    minMultipartPartSize,
+		Concurrency: 2,
+	})
+	if err == nil {
+		t.Fatalf("expected error")
+	}
+	if !state.aborted.Load() {
+		t.Fatalf("expected abort request")
+	}
+}
+
+func
TestCOSMultipartFallbackSmallSize(t *testing.T) { + server, state := newMockCOSServer(t, 0) + defer server.Close() + state.uploadID = "cos-small" + + sdk := newTestCOSClient(t, server) + data := bytes.Repeat([]byte("e"), int(minMultipartPartSize-1)) + size := int64(len(data)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err != nil { + t.Fatalf("write failed: %v", err) + } + if state.putCount != 1 { + t.Fatalf("expected single PUT fallback, got %d", state.putCount) + } + if string(state.putBody) != string(data) { + t.Fatalf("unexpected put body") + } +} + +func TestCOSMultipartFallbackUnknownSmall(t *testing.T) { + server, state := newMockCOSServer(t, 0) + defer server.Close() + state.uploadID = "cos-unknown" + + sdk := newTestCOSClient(t, server) + data := bytes.Repeat([]byte("f"), int(minMultipartPartSize-1)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), nil, nil); err != nil { + t.Fatalf("write failed: %v", err) + } + if state.putCount != 1 { + t.Fatalf("expected fallback PUT for unknown small") + } + if string(state.putBody) != string(data) { + t.Fatalf("unexpected put body") + } +} + +func TestCOSMultipartTooManyParts(t *testing.T) { + server, _ := newMockCOSServer(t, 0) + defer server.Close() + + sdk := newTestCOSClient(t, server) + huge := int64(maxMultipartParts+1) * defaultParallelMultipartPartSize + err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), &huge, &ParallelMultipartOption{ + PartSize: defaultParallelMultipartPartSize, + }) + if err == nil { + t.Fatalf("expected error for too many parts") + } +} + +func TestCOSMultipartUploadPartError(t *testing.T) { + server, _ := newMockCOSServer(t, 1) + defer server.Close() + + sdk := newTestCOSClient(t, server) + data := bytes.Repeat([]byte("h"), int(minMultipartPartSize*2)) + size := int64(len(data)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, &ParallelMultipartOption{ + PartSize: minMultipartPartSize, + Concurrency: 2, + }); err == nil { + t.Fatalf("expected upload part error") + } +} + +func TestCOSParallelMultipartUnknownSize(t *testing.T) { + server, state := newMockCOSServer(t, 0) + defer server.Close() + state.uploadID = "cos-unknown-size" + + sdk := newTestCOSClient(t, server) + data := bytes.Repeat([]byte("g"), int(minMultipartPartSize+1)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), nil, &ParallelMultipartOption{ + PartSize: minMultipartPartSize, + Concurrency: 2, + }); err != nil { + t.Fatalf("write failed: %v", err) + } + if len(state.parts) != 2 { + t.Fatalf("expected multipart upload with unknown size") + } +} + +func TestCOSMultipartEmptyReader(t *testing.T) { + server, state := newMockCOSServer(t, 0) + defer server.Close() + + sdk := newTestCOSClient(t, server) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(nil), nil, nil); err != nil { + t.Fatalf("expected nil error for empty reader, got %v", err) + } + if state.putCount != 0 && len(state.parts) != 0 { + t.Fatalf("no upload should happen for empty reader") + } +} + +func TestCOSMultipartContextCanceled(t *testing.T) { + server, _ := newMockCOSServer(t, 0) + defer server.Close() + + sdk := newTestCOSClient(t, server) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := sdk.WriteMultipartParallel(ctx, "object", bytes.NewReader([]byte("data")), nil, nil); 
err == nil { + t.Fatalf("expected context canceled error") + } +} + +func TestCOSMultipartCreateFail(t *testing.T) { + server, state := newMockCOSServer(t, 0) + defer server.Close() + state.failCreate = true + + sdk := newTestCOSClient(t, server) + data := bytes.Repeat([]byte("j"), int(minMultipartPartSize+1)) + size := int64(len(data)) + if err := sdk.WriteMultipartParallel(context.Background(), "object", bytes.NewReader(data), &size, nil); err == nil { + t.Fatalf("expected create multipart error") + } +} diff --git a/pkg/fileservice/parallel_upload_test.go b/pkg/fileservice/parallel_upload_test.go new file mode 100644 index 0000000000000..e3e3fb435931b --- /dev/null +++ b/pkg/fileservice/parallel_upload_test.go @@ -0,0 +1,511 @@ +// Copyright 2025 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileservice + +import ( + "bytes" + "context" + "io" + "iter" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestNormalizeParallelOption(t *testing.T) { + cases := []struct { + name string + opt *ParallelMultipartOption + expectSize int64 + expectConc int + expectExpire bool + }{ + { + name: "default values", + opt: nil, + expectSize: defaultParallelMultipartPartSize, + expectConc: runtime.NumCPU(), + }, + { + name: "clamp below min", + opt: &ParallelMultipartOption{ + PartSize: minMultipartPartSize - 1, + }, + expectSize: minMultipartPartSize, + expectConc: runtime.NumCPU(), + }, + { + name: "clamp above max", + opt: &ParallelMultipartOption{ + PartSize: maxMultipartPartSize + 1, + }, + expectSize: maxMultipartPartSize, + expectConc: runtime.NumCPU(), + }, + { + name: "custom values", + opt: &ParallelMultipartOption{ + PartSize: 16 << 20, + Concurrency: 3, + }, + expectSize: 16 << 20, + expectConc: 3, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + res := normalizeParallelOption(c.opt) + if res.PartSize != c.expectSize { + t.Fatalf("part size mismatch, got %d expect %d", res.PartSize, c.expectSize) + } + if res.Concurrency != c.expectConc { + t.Fatalf("concurrency mismatch, got %d expect %d", res.Concurrency, c.expectConc) + } + }) + } +} + +func TestGetParallelUploadPoolSingleton(t *testing.T) { + p1 := getParallelUploadPool() + p2 := getParallelUploadPool() + if p1 == nil || p2 == nil { + t.Fatalf("pool not initialized") + } + if p1 != p2 { + t.Fatalf("expected singleton pool") + } +} + +type mockParallelStorage struct { + supports bool + exists bool + + writeCalled int + mpCalled int + + lastKey string + lastData []byte + lastSize *int64 + lastOpt ParallelMultipartOption + lastExpire *time.Time +} + +func (m *mockParallelStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] { + return func(yield func(*DirEntry, error) bool) {} +} + +func (m *mockParallelStorage) Stat(ctx context.Context, key string) (int64, error) { + return 0, nil +} + +func (m *mockParallelStorage) Exists(ctx context.Context, key string) (bool, error) { + return m.exists, nil +} + +func (m 
*mockParallelStorage) Write(ctx context.Context, key string, r io.Reader, sizeHint *int64, expire *time.Time) error { + m.writeCalled++ + m.lastKey = key + m.lastSize = sizeHint + m.lastExpire = expire + b, _ := io.ReadAll(r) + m.lastData = b + return nil +} + +func (m *mockParallelStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(nil)), nil +} + +func (m *mockParallelStorage) Delete(ctx context.Context, keys ...string) error { + return nil +} + +func (m *mockParallelStorage) SupportsParallelMultipart() bool { + return m.supports +} + +func (m *mockParallelStorage) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error { + m.mpCalled++ + m.lastKey = key + if opt != nil { + m.lastOpt = *opt + } + m.lastSize = sizeHint + b, _ := io.ReadAll(r) + m.lastData = b + return nil +} + +func TestS3FSWriteUsesParallelWhenForced(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelForce, + } + + data := []byte("hello") + vector := IOVector{ + FilePath: "obj", + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 1 { + t.Fatalf("expected parallel write, got %d", storage.mpCalled) + } + if storage.writeCalled != 0 { + t.Fatalf("unexpected non-parallel write") + } + if storage.lastKey != "obj" { + t.Fatalf("unexpected key %s", storage.lastKey) + } + if string(storage.lastData) != "hello" { + t.Fatalf("unexpected data %s", string(storage.lastData)) + } +} + +func TestS3FSWriteDisableParallel(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelOff, + } + + data := bytes.Repeat([]byte("a"), int(minMultipartPartSize+1)) + vector := IOVector{ + FilePath: "large", + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 0 { + t.Fatalf("parallel write should be disabled") + } + if storage.writeCalled != 1 { + t.Fatalf("expected single write, got %d", storage.writeCalled) + } + if storage.lastKey != "large" { + t.Fatalf("unexpected key %s", storage.lastKey) + } +} + +func TestS3FSWriteUnknownSizeUsesParallelWhenForced(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelForce, + } + + reader := strings.NewReader("data") + vector := IOVector{ + FilePath: "unknown", + Entries: []IOEntry{ + {Offset: 0, Size: -1, ReaderForWrite: reader}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 1 { + t.Fatalf("expected parallel write for unknown size") + } + if storage.lastSize != nil { + t.Fatalf("size hint should be nil for unknown size") + } +} + +func TestS3FSWriteUnknownSizeDefaultNoParallel(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: 
ParallelOff, + } + + reader := strings.NewReader("data") + vector := IOVector{ + FilePath: "unknown", + Entries: []IOEntry{ + {Offset: 0, Size: -1, ReaderForWrite: reader}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 0 { + t.Fatalf("parallel write should not be used by default") + } + if storage.writeCalled != 1 { + t.Fatalf("expected single write by default") + } +} + +func TestS3FSSmallSizeSkipsParallel(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelAuto, + } + + data := bytes.Repeat([]byte("b"), int(minMultipartPartSize-1)) + vector := IOVector{ + FilePath: "small", + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.writeCalled != 1 { + t.Fatalf("expected non-parallel write for small object") + } + if storage.mpCalled != 0 { + t.Fatalf("unexpected parallel write") + } +} + +func TestS3FSLargeSizeDefaultNoParallel(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelOff, + } + + data := bytes.Repeat([]byte("d"), int(minMultipartPartSize+1)) + vector := IOVector{ + FilePath: "large-default", + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 0 { + t.Fatalf("parallel write should be disabled by default") + } + if storage.writeCalled != 1 { + t.Fatalf("expected single write by default, got %d", storage.writeCalled) + } +} + +func TestS3FSAutoUsesParallelForLarge(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelAuto, + } + + data := bytes.Repeat([]byte("p"), int(minMultipartPartSize+1)) + vector := IOVector{ + FilePath: "auto-large", + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 1 { + t.Fatalf("expected parallel write with auto mode") + } +} + +func TestS3FSForceParallelFallbackWhenUnsupported(t *testing.T) { + storage := &mockParallelStorage{supports: false} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelForce, + } + + data := []byte("force") + vector := IOVector{ + FilePath: "force", + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 0 { + t.Fatalf("parallel write should not happen when unsupported") + } + if storage.writeCalled != 1 { + t.Fatalf("expected fallback single write") + } +} + +func TestS3FSParallelPassesExpire(t *testing.T) { + storage := &mockParallelStorage{supports: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelForce, + } + + data := bytes.Repeat([]byte("c"), int(minMultipartPartSize+1)) + 
expire := time.Now().Add(time.Hour) + vector := IOVector{ + FilePath: "with-expire", + ExpireAt: expire, + Entries: []IOEntry{ + {Offset: 0, Size: int64(len(data)), Data: data}, + }, + } + + if err := fs.Write(context.Background(), vector); err != nil { + t.Fatalf("write failed: %v", err) + } + if storage.mpCalled != 1 { + t.Fatalf("expected parallel write") + } + if storage.lastOpt.Expire == nil || !storage.lastOpt.Expire.Equal(expire) { + t.Fatalf("expire not propagated") + } +} + +func TestS3FSWriteFileExists(t *testing.T) { + storage := &mockParallelStorage{supports: true, exists: true} + fs := &S3FS{ + name: "s3", + storage: storage, + ioMerger: NewIOMerger(), + asyncUpdate: true, + parallelMode: ParallelOff, + } + vector := IOVector{ + FilePath: "dup", + Entries: []IOEntry{ + {Offset: 0, Size: 1, Data: []byte("x")}, + }, + } + err := fs.Write(context.Background(), vector) + if err == nil { + t.Fatalf("expected file exists error") + } +} + +type countingParallelStorage struct { + supports bool + current atomic.Int64 + max atomic.Int64 +} + +func (c *countingParallelStorage) List(ctx context.Context, prefix string) iter.Seq2[*DirEntry, error] { + return func(yield func(*DirEntry, error) bool) {} +} + +func (c *countingParallelStorage) Stat(ctx context.Context, key string) (int64, error) { + return 0, nil +} + +func (c *countingParallelStorage) Exists(ctx context.Context, key string) (bool, error) { + return false, nil +} + +func (c *countingParallelStorage) Write(ctx context.Context, key string, r io.Reader, sizeHint *int64, expire *time.Time) error { + return nil +} + +func (c *countingParallelStorage) Read(ctx context.Context, key string, min *int64, max *int64) (io.ReadCloser, error) { + return io.NopCloser(strings.NewReader("")), nil +} + +func (c *countingParallelStorage) Delete(ctx context.Context, keys ...string) error { + return nil +} + +func (c *countingParallelStorage) SupportsParallelMultipart() bool { + return c.supports +} + +func (c *countingParallelStorage) WriteMultipartParallel(ctx context.Context, key string, r io.Reader, sizeHint *int64, opt *ParallelMultipartOption) error { + cur := c.current.Add(1) + defer c.current.Add(-1) + for { + old := c.max.Load() + if cur <= old || c.max.CompareAndSwap(old, cur) { + break + } + } + time.Sleep(20 * time.Millisecond) + return nil +} + +func TestObjectStorageSemaphoreLimitsParallel(t *testing.T) { + upstream := &countingParallelStorage{supports: true} + sem := newObjectStorageSemaphore(upstream, 1) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _ = sem.WriteMultipartParallel(context.Background(), "a", strings.NewReader("1"), nil, nil) + }() + go func() { + defer wg.Done() + _ = sem.WriteMultipartParallel(context.Background(), "b", strings.NewReader("2"), nil, nil) + }() + wg.Wait() + + if upstream.max.Load() > 1 { + t.Fatalf("expected semaphore to cap concurrency at 1, got %d", upstream.max.Load()) + } +} diff --git a/pkg/fileservice/qcloud_sdk.go b/pkg/fileservice/qcloud_sdk.go index 008806f8b609a..165199aa37563 100644 --- a/pkg/fileservice/qcloud_sdk.go +++ b/pkg/fileservice/qcloud_sdk.go @@ -25,7 +25,9 @@ import ( "net/url" "os" gotrace "runtime/trace" + "sort" "strconv" + "sync" "time" "github.com/matrixorigin/matrixone/pkg/common/moerr" @@ -125,6 +127,7 @@ func NewQCloudSDK( } var _ ObjectStorage = new(QCloudSDK) +var _ ParallelMultipartWriter = new(QCloudSDK) func (a *QCloudSDK) List( ctx context.Context, @@ -280,6 +283,269 @@ func (a *QCloudSDK) Write( return } +func (a *QCloudSDK) 
SupportsParallelMultipart() bool { + return true +} + +func (a *QCloudSDK) WriteMultipartParallel( + ctx context.Context, + key string, + r io.Reader, + sizeHint *int64, + opt *ParallelMultipartOption, +) (err error) { + defer wrapSizeMismatchErr(&err) + + options := normalizeParallelOption(opt) + if sizeHint != nil && *sizeHint < minMultipartPartSize { + return a.Write(ctx, key, r, sizeHint, options.Expire) + } + if sizeHint != nil { + expectedParts := (*sizeHint + options.PartSize - 1) / options.PartSize + if expectedParts > maxMultipartParts { + return moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", expectedParts) + } + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + bufPool := sync.Pool{ + New: func() any { + buf := make([]byte, options.PartSize) + return &buf + }, + } + + readChunk := func() (bufPtr *[]byte, buf []byte, n int, err error) { + bufPtr = bufPool.Get().(*[]byte) + raw := *bufPtr + n, err = io.ReadFull(r, raw) + switch { + case errors.Is(err, io.EOF): + bufPool.Put(bufPtr) + return nil, nil, 0, io.EOF + case errors.Is(err, io.ErrUnexpectedEOF): + err = io.EOF + return bufPtr, raw, n, err + case err != nil: + bufPool.Put(bufPtr) + return nil, nil, 0, err + default: + return bufPtr, raw, n, nil + } + } + + firstBufPtr, firstBuf, firstN, err := readChunk() + if err != nil && !errors.Is(err, io.EOF) { + return err + } + if firstN == 0 && errors.Is(err, io.EOF) { + return nil + } + if errors.Is(err, io.EOF) && int64(firstN) < minMultipartPartSize { + data := make([]byte, firstN) + copy(data, firstBuf[:firstN]) + bufPool.Put(firstBufPtr) + size := int64(firstN) + return a.Write(ctx, key, bytes.NewReader(data), &size, options.Expire) + } + + var expiresHeader string + if options.Expire != nil { + expiresHeader = options.Expire.UTC().Format(http.TimeFormat) + } + + initOpt := &cos.InitiateMultipartUploadOptions{ + ObjectPutHeaderOptions: &cos.ObjectPutHeaderOptions{ + Expires: expiresHeader, + }, + } + output, createErr := DoWithRetry("cos initiate multipart upload", func() (*cos.InitiateMultipartUploadResult, error) { + res, _, e := a.client.Object.InitiateMultipartUpload(ctx, key, initOpt) + return res, e + }, maxRetryAttemps, IsRetryableError) + if createErr != nil { + bufPool.Put(firstBufPtr) + return createErr + } + + defer func() { + if err != nil { + _, _ = a.client.Object.AbortMultipartUpload(ctx, key, output.UploadID) + } + }() + + type partJob struct { + num int32 + buf []byte + bufPtr *[]byte + n int + } + + var ( + partNum int32 + parts []cos.Object + partsLock sync.Mutex + wg sync.WaitGroup + errOnce sync.Once + firstErr error + ) + + setErr := func(e error) { + if e == nil { + return + } + errOnce.Do(func() { + firstErr = e + cancel() + }) + } + + jobCh := make(chan partJob, options.Concurrency*2) + + startWorker := func() error { + wg.Add(1) + return getParallelUploadPool().Submit(func() { + defer wg.Done() + for job := range jobCh { + if ctx.Err() != nil { + if job.bufPtr != nil { + bufPool.Put(job.bufPtr) + } + continue + } + uploadOpt := &cos.ObjectUploadPartOptions{ + ContentLength: int64(job.n), + } + resp, uploadErr := DoWithRetry("cos upload part", func() (*cos.Response, error) { + return a.client.Object.UploadPart(ctx, key, output.UploadID, int(job.num), bytes.NewReader(job.buf[:job.n]), uploadOpt) + }, maxRetryAttemps, IsRetryableError) + if uploadErr != nil { + setErr(uploadErr) + if job.bufPtr != nil { + bufPool.Put(job.bufPtr) + } + continue + } + etag := "" + if resp != nil && resp.Header != nil { + etag = 
resp.Header.Get("ETag") + } + if job.bufPtr != nil { + bufPool.Put(job.bufPtr) + } + partsLock.Lock() + parts = append(parts, cos.Object{ + PartNumber: int(job.num), + ETag: etag, + }) + partsLock.Unlock() + } + }) + } + + for i := 0; i < options.Concurrency; i++ { + if submitErr := startWorker(); submitErr != nil { + setErr(submitErr) + break + } + } + + sendJob := func(bufPtr *[]byte, buf []byte, n int) bool { + partNum++ + if partNum > maxMultipartParts { + setErr(moerr.NewInternalErrorNoCtxf("too many parts for multipart upload: %d", partNum)) + if bufPtr != nil { + bufPool.Put(bufPtr) + } + return false + } + job := partJob{ + num: partNum, + buf: buf, + bufPtr: bufPtr, + n: n, + } + select { + case jobCh <- job: + return true + case <-ctx.Done(): + if bufPtr != nil { + bufPool.Put(bufPtr) + } + setErr(ctx.Err()) + return false + } + } + + if !sendJob(firstBufPtr, firstBuf, firstN) { + close(jobCh) + wg.Wait() + if firstErr != nil { + return firstErr + } + return ctx.Err() + } + + for { + nextBufPtr, nextBuf, nextN, readErr := readChunk() + if errors.Is(readErr, io.EOF) && nextN == 0 { + break + } + if readErr != nil && !errors.Is(readErr, io.EOF) { + setErr(readErr) + if nextBufPtr != nil { + bufPool.Put(nextBufPtr) + } + break + } + if nextN == 0 { + if nextBufPtr != nil { + bufPool.Put(nextBufPtr) + } + break + } + if !sendJob(nextBufPtr, nextBuf, nextN) { + break + } + if readErr != nil && errors.Is(readErr, io.EOF) { + break + } + } + + close(jobCh) + wg.Wait() + + if firstErr != nil { + err = firstErr + return err + } + if len(parts) == 0 { + return nil + } + if len(parts) != int(partNum) { + return moerr.NewInternalErrorNoCtxf("multipart upload incomplete, expect %d parts got %d", partNum, len(parts)) + } + + sort.Slice(parts, func(i, j int) bool { + return parts[i].PartNumber < parts[j].PartNumber + }) + + completeOpt := &cos.CompleteMultipartUploadOptions{ + Parts: parts, + } + _, err = DoWithRetry("cos complete multipart upload", func() (*cos.CompleteMultipartUploadResult, error) { + res, _, e := a.client.Object.CompleteMultipartUpload(ctx, key, output.UploadID, completeOpt) + return res, e + }, maxRetryAttemps, IsRetryableError) + if err != nil { + return err + } + + return nil +} + func (a *QCloudSDK) Read( ctx context.Context, key string, diff --git a/pkg/fileservice/s3_fs.go b/pkg/fileservice/s3_fs.go index 7a2a58570a838..c1d562b87d77e 100644 --- a/pkg/fileservice/s3_fs.go +++ b/pkg/fileservice/s3_fs.go @@ -52,6 +52,8 @@ type S3FS struct { perfCounterSets []*perfcounter.CounterSet ioMerger *IOMerger + + parallelMode ParallelMode } // key mapping scheme: @@ -76,6 +78,7 @@ func NewS3FS( asyncUpdate: true, perfCounterSets: perfCounterSets, ioMerger: NewIOMerger(), + parallelMode: args.ParallelMode, } var err error @@ -452,8 +455,30 @@ func (s *S3FS) write(ctx context.Context, vector IOVector) (bytesWritten int, er expire = &vector.ExpireAt } key := s.pathToKey(path.File) - if err := s.storage.Write(ctx, key, reader, size, expire); err != nil { - return 0, err + enableParallel := false + switch s.parallelMode { + case ParallelForce: + enableParallel = true + case ParallelAuto: + if size == nil || *size >= minMultipartPartSize { + enableParallel = true + } + } + + if pmw, ok := s.storage.(ParallelMultipartWriter); ok && pmw.SupportsParallelMultipart() && + enableParallel { + opt := &ParallelMultipartOption{ + PartSize: defaultParallelMultipartPartSize, + Concurrency: runtime.NumCPU(), + Expire: expire, + } + if err := pmw.WriteMultipartParallel(ctx, key, reader, size, opt); 
err != nil { + return 0, err + } + } else { + if err := s.storage.Write(ctx, key, reader, size, expire); err != nil { + return 0, err + } } // write to disk cache diff --git a/pkg/fileservice/s3_fs_test.go b/pkg/fileservice/s3_fs_test.go index 042bc27369395..c01189cec2599 100644 --- a/pkg/fileservice/s3_fs_test.go +++ b/pkg/fileservice/s3_fs_test.go @@ -226,6 +226,69 @@ func TestDynamicS3Opts(t *testing.T) { }) } +func TestDynamicS3NoKey(t *testing.T) { + ctx := WithParallelMode(context.Background(), ParallelForce) + prefix := "etl-prefix" + buf := new(strings.Builder) + w := csv.NewWriter(buf) + err := w.Write([]string{ + "s3-no-key", + "disk", + "", + t.TempDir(), + prefix, + "s3-nokey", + }) + assert.Nil(t, err) + w.Flush() + + fs, path, err := GetForETL(ctx, nil, JoinPath(buf.String(), "foo.txt")) + assert.Nil(t, err) + assert.Equal(t, "foo.txt", path) + s3fs, ok := fs.(*S3FS) + assert.True(t, ok) + assert.Equal(t, ParallelForce, s3fs.parallelMode) + assert.Equal(t, prefix, s3fs.keyPrefix) + assert.IsType(t, &diskObjectStorage{}, baseStorage(t, s3fs)) +} + +func TestDynamicMinio(t *testing.T) { + ctx := WithParallelMode(context.Background(), ParallelForce) + buf := new(strings.Builder) + w := csv.NewWriter(buf) + err := w.Write([]string{ + "minio", + "http://127.0.0.1:9000", + "us-east-1", + "bucket", + "ak", + "sk", + "pref", + "minio-etl", + }) + assert.Nil(t, err) + w.Flush() + + fs, path, err := GetForETL(ctx, nil, JoinPath(buf.String(), "bar.txt")) + assert.Nil(t, err) + assert.Equal(t, "bar.txt", path) + s3fs, ok := fs.(*S3FS) + assert.True(t, ok) + assert.Equal(t, ParallelForce, s3fs.parallelMode) + assert.Equal(t, "pref", s3fs.keyPrefix) + assert.IsType(t, &MinioSDK{}, baseStorage(t, s3fs)) +} + +func baseStorage(t *testing.T, fs *S3FS) ObjectStorage { + trace, ok := fs.storage.(*objectStorageHTTPTrace) + assert.True(t, ok) + metrics, ok := trace.upstream.(*objectStorageMetrics) + assert.True(t, ok) + sem, ok := metrics.upstream.(*objectStorageSemaphore) + assert.True(t, ok) + return sem.upstream +} + func BenchmarkS3FS(b *testing.B) { cacheDir := b.TempDir() b.ResetTimer() diff --git a/pkg/frontend/data_branch.go b/pkg/frontend/data_branch.go index 99eb2519c4efb..88df513ac5341 100644 --- a/pkg/frontend/data_branch.go +++ b/pkg/frontend/data_branch.go @@ -144,6 +144,42 @@ type batchWithKind struct { batch *batch.Batch } +type emitFunc func(batchWithKind) (stop bool, err error) + +func newEmitter( + ctx context.Context, stopCh <-chan struct{}, retCh chan batchWithKind, +) emitFunc { + return func(wrapped batchWithKind) (bool, error) { + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-stopCh: + return true, nil + default: + } + + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-stopCh: + return true, nil + case retCh <- wrapped: + return false, nil + } + } +} + +func emitBatch( + emit emitFunc, wrapped batchWithKind, forTombstone bool, pool *retBatchList, +) (bool, error) { + stop, err := emit(wrapped) + if stop || err != nil { + pool.releaseRetBatch(wrapped.batch, forTombstone) + return stop, err + } + return false, nil +} + type retBatchList struct { mu sync.Mutex // 0: data @@ -633,11 +669,17 @@ func diffMergeAgency( }() var ( - dagInfo branchMetaInfo - tblStuff tableStuff - copt compositeOption - ctx, cancel = context.WithCancel(execCtx.reqCtx) + ctx context.Context + cancel context.CancelFunc + ) + //ctx = fileservice.WithParallelMode(execCtx.reqCtx, fileservice.ParallelForce) + ctx, cancel = context.WithCancel(execCtx.reqCtx) + + var ( + 
dagInfo branchMetaInfo + tblStuff tableStuff + copt compositeOption ok bool diffStmt *tree.DataBranchDiff mergeStmt *tree.DataBranchMerge @@ -683,7 +725,11 @@ func diffMergeAgency( done bool wg = new(sync.WaitGroup) outputErr atomic.Value - retBatCh = make(chan batchWithKind, 100) + retBatCh = make(chan batchWithKind, 10) + stopCh = make(chan struct{}) + stopOnce sync.Once + emit emitFunc + stop func() waited bool ) @@ -699,6 +745,13 @@ func diffMergeAgency( } }() + emit = newEmitter(ctx, stopCh, retBatCh) + stop = func() { + stopOnce.Do(func() { + close(stopCh) + }) + } + if diffStmt != nil { if err = buildOutputSchema(ctx, ses, diffStmt, tblStuff); err != nil { return @@ -725,7 +778,7 @@ func diffMergeAgency( // 5. as file if err2 := satisfyDiffOutputOpt( - ctx, cancel, ses, bh, diffStmt, dagInfo, tblStuff, retBatCh, + ctx, cancel, stop, ses, bh, diffStmt, dagInfo, tblStuff, retBatCh, ); err2 != nil { outputErr.Store(err2) } @@ -739,7 +792,7 @@ func diffMergeAgency( }() if err = diffOnBase( - ctx, ses, bh, wg, dagInfo, tblStuff, retBatCh, copt, + ctx, ses, bh, wg, dagInfo, tblStuff, copt, emit, ); err != nil { return } @@ -909,15 +962,19 @@ func writeReplaceRowValues( tblStuff tableStuff, row []any, buf *bytes.Buffer, -) { +) error { buf.WriteString("(") for i, idx := range tblStuff.def.visibleIdxes { - formatValIntoString(ses, row[idx], tblStuff.def.colTypes[idx], buf) + if err := formatValIntoString(ses, row[idx], tblStuff.def.colTypes[idx], buf); err != nil { + return err + } if i != len(tblStuff.def.visibleIdxes)-1 { buf.WriteString(",") } } buf.WriteString(")") + + return nil } func writeDeleteRowValues( @@ -925,12 +982,14 @@ func writeDeleteRowValues( tblStuff tableStuff, row []any, buf *bytes.Buffer, -) { +) error { if len(tblStuff.def.pkColIdxes) > 1 { buf.WriteString("(") } for i, colIdx := range tblStuff.def.pkColIdxes { - formatValIntoString(ses, row[colIdx], tblStuff.def.colTypes[colIdx], buf) + if err := formatValIntoString(ses, row[colIdx], tblStuff.def.colTypes[colIdx], buf); err != nil { + return err + } if i != len(tblStuff.def.pkColIdxes)-1 { buf.WriteString(",") } @@ -938,6 +997,8 @@ func writeDeleteRowValues( if len(tblStuff.def.pkColIdxes) > 1 { buf.WriteString(")") } + + return nil } func appendBatchRowsAsSQLValues( @@ -1005,9 +1066,13 @@ func appendBatchRowsAsSQLValues( tmpValsBuffer.Reset() if wrapped.kind == diffDelete { - writeDeleteRowValues(ses, tblStuff, row, tmpValsBuffer) + if err = writeDeleteRowValues(ses, tblStuff, row, tmpValsBuffer); err != nil { + return + } } else { - writeReplaceRowValues(ses, tblStuff, row, tmpValsBuffer) + if err = writeReplaceRowValues(ses, tblStuff, row, tmpValsBuffer); err != nil { + return + } } if tmpValsBuffer.Len() == 0 { @@ -1104,6 +1169,7 @@ func mergeDiffs( func satisfyDiffOutputOpt( ctx context.Context, cancel context.CancelFunc, + stop func(), ses *Session, bh BackgroundExec, stmt *tree.DataBranchDiff, @@ -1161,10 +1227,10 @@ func satisfyDiffOutputOpt( rows = append(rows, row) if stmt.OutputOpt != nil && stmt.OutputOpt.Limit != nil && - int64(mrs.GetRowCount()) >= *stmt.OutputOpt.Limit { + int64(len(rows)) >= *stmt.OutputOpt.Limit { // hit limit, cancel producers but keep draining the channel hitLimit = true - cancel() + stop() break } } @@ -2049,6 +2115,12 @@ func writeCSV( case <-inputCtx.Done(): err = errors.Join(err, inputCtx.Err()) stop = true + case e := <-writerErr: + if e != nil { + err = errors.Join(err, e) + } + stop = true + cancelCtx() case e, ok := <-errChan: if !ok { errOpen = false @@ -2217,8 +2289,8 @@ 
func diffOnBase( wg *sync.WaitGroup, dagInfo branchMetaInfo, tblStuff tableStuff, - retCh chan batchWithKind, copt compositeOption, + emit emitFunc, ) (err error) { defer func() { @@ -2278,8 +2350,8 @@ func diffOnBase( } if err = hashDiff( - ctx, ses, bh, tblStuff, dagInfo, retCh, - copt, tarHandle, baseHandle, + ctx, ses, bh, tblStuff, dagInfo, + copt, emit, tarHandle, baseHandle, ); err != nil { return } @@ -2407,8 +2479,8 @@ func hashDiff( bh BackgroundExec, tblStuff tableStuff, dagInfo branchMetaInfo, - retCh chan batchWithKind, copt compositeOption, + emit emitFunc, tarHandle []engine.ChangesHandle, baseHandle []engine.ChangesHandle, ) ( @@ -2452,7 +2524,7 @@ func hashDiff( if dagInfo.lcaType == lcaEmpty { if err = hashDiffIfNoLCA( - ctx, ses, tblStuff, retCh, copt, + ctx, ses, tblStuff, copt, emit, tarDataHashmap, tarTombstoneHashmap, baseDataHashmap, baseTombstoneHashmap, ); err != nil { @@ -2460,7 +2532,7 @@ func hashDiff( } } else { if err = hashDiffIfHasLCA( - ctx, ses, bh, dagInfo, tblStuff, retCh, copt, + ctx, ses, bh, dagInfo, tblStuff, copt, emit, tarDataHashmap, tarTombstoneHashmap, baseDataHashmap, baseTombstoneHashmap, ); err != nil { @@ -2477,8 +2549,8 @@ func hashDiffIfHasLCA( bh BackgroundExec, dagInfo branchMetaInfo, tblStuff tableStuff, - retCh chan batchWithKind, copt compositeOption, + emit emitFunc, tarDataHashmap databranchutils.BranchHashmap, tarTombstoneHashmap databranchutils.BranchHashmap, baseDataHashmap databranchutils.BranchHashmap, @@ -2511,7 +2583,11 @@ func hashDiffIfHasLCA( handleTarDeleteAndUpdates := func(wrapped batchWithKind) (err2 error) { if len(baseUpdateBatches) == 0 && len(baseDeleteBatches) == 0 { // no need to check conflict - retCh <- wrapped + if stop, e := emitBatch(emit, wrapped, false, tblStuff.retPool); e != nil { + return e + } else if stop { + return nil + } return nil } @@ -2576,7 +2652,12 @@ func hashDiffIfHasLCA( } buf := acquireBuffer(tblStuff.bufPool) - formatValIntoString(ses, tarRow[0], tblStuff.def.colTypes[tblStuff.def.pkColIdx], buf) + if err3 = formatValIntoString( + ses, tarRow[0], tblStuff.def.colTypes[tblStuff.def.pkColIdx], buf, + ); err3 != nil { + releaseBuffer(tblStuff.bufPool, buf) + return + } err3 = moerr.NewInternalErrorNoCtxf( "conflict: %s %s and %s %s on pk(%v) with different values", @@ -2653,7 +2734,13 @@ func hashDiffIfHasLCA( return } - retCh <- wrapped + stop, e := emitBatch(emit, wrapped, false, tblStuff.retPool) + if e != nil { + return e + } + if stop { + return nil + } return } @@ -2725,12 +2812,16 @@ func hashDiffIfHasLCA( } else { err2 = handleTarDeleteAndUpdates(wrapped) } + + if errors.Is(err2, context.Canceled) { + err2 = nil + cancel() + } } if err2 != nil { first = err2 cancel() - tblStuff.retPool.releaseRetBatch(wrapped.batch, false) } } @@ -2758,24 +2849,49 @@ func hashDiffIfHasLCA( // what can I do with these left base updates/inserts ? 
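// NOTE (editor, illustrative only, not part of this change): the emit/stop plumbing used
// below replaces direct `retCh <-` sends so producers can exit cleanly once the consumer
// hits its output limit. A minimal standalone sketch of the same pattern, with
// hypothetical names, assuming a context plus a dedicated stop channel:
//
//	func send[T any](ctx context.Context, stopCh <-chan struct{}, out chan<- T, v T) (stopped bool, err error) {
//		select {
//		case <-ctx.Done():
//			return false, ctx.Err()
//		case <-stopCh:
//			return true, nil
//		case out <- v:
//			return false, nil
//		}
//	}
//
// When `stopped` is true the caller still owns `v`, so any pooled resources behind it must
// be released, which is what the loops below do for the leftover base update/delete batches.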
if copt.conflictOpt == nil { - for _, w := range baseUpdateBatches { - retCh <- w + stopped := false + for i, w := range baseUpdateBatches { + var stop bool + if stop, err = emitBatch(emit, w, false, tblStuff.retPool); err != nil { + return err + } + if stop { + stopped = true + for j := i + 1; j < len(baseUpdateBatches); j++ { + tblStuff.retPool.releaseRetBatch(baseUpdateBatches[j].batch, false) + } + for _, bw := range baseDeleteBatches { + tblStuff.retPool.releaseRetBatch(bw.batch, false) + } + break + } } - for _, w := range baseDeleteBatches { - retCh <- w + if !stopped { + for i, w := range baseDeleteBatches { + var stop bool + if stop, err = emitBatch(emit, w, false, tblStuff.retPool); err != nil { + return err + } + if stop { + for j := i + 1; j < len(baseDeleteBatches); j++ { + tblStuff.retPool.releaseRetBatch(baseDeleteBatches[j].batch, false) + } + break + } + } } } - return diffDataHelper(ctx, ses, copt, tblStuff, retCh, tarDataHashmap, baseDataHashmap) + return diffDataHelper(ctx, ses, copt, tblStuff, emit, tarDataHashmap, baseDataHashmap) } func hashDiffIfNoLCA( ctx context.Context, ses *Session, tblStuff tableStuff, - retCh chan batchWithKind, copt compositeOption, + emit emitFunc, tarDataHashmap databranchutils.BranchHashmap, tarTombstoneHashmap databranchutils.BranchHashmap, baseDataHashmap databranchutils.BranchHashmap, @@ -2802,7 +2918,7 @@ func hashDiffIfNoLCA( return } - return diffDataHelper(ctx, ses, copt, tblStuff, retCh, tarDataHashmap, baseDataHashmap) + return diffDataHelper(ctx, ses, copt, tblStuff, emit, tarDataHashmap, baseDataHashmap) } func compareRowInWrappedBatches( @@ -3005,7 +3121,10 @@ func checkConflictAndAppendToBat( case tree.CONFLICT_FAIL: buf := acquireBuffer(tblStuff.bufPool) for i, idx := range tblStuff.def.pkColIdxes { - formatValIntoString(ses, tarTuple[idx], tblStuff.def.colTypes[idx], buf) + if err2 = formatValIntoString(ses, tarTuple[idx], tblStuff.def.colTypes[idx], buf); err2 != nil { + releaseBuffer(tblStuff.bufPool, buf) + return err2 + } if i < len(tblStuff.def.pkColIdxes)-1 { buf.WriteString(",") } @@ -3041,7 +3160,7 @@ func diffDataHelper( ses *Session, copt compositeOption, tblStuff tableStuff, - retCh chan batchWithKind, + emit emitFunc, tarDataHashmap databranchutils.BranchHashmap, baseDataHashmap databranchutils.BranchHashmap, ) (err error) { @@ -3156,16 +3275,24 @@ func diffDataHelper( default: } - retCh <- batchWithKind{ + if stop, err3 := emitBatch(emit, batchWithKind{ batch: tarBat, kind: diffInsert, name: tblStuff.tarRel.GetTableName(), + }, false, tblStuff.retPool); err3 != nil { + return err3 + } else if stop { + return nil } - retCh <- batchWithKind{ + if stop, err3 := emitBatch(emit, batchWithKind{ batch: baseBat, kind: diffInsert, name: tblStuff.baseRel.GetTableName(), + }, false, tblStuff.retPool); err3 != nil { + return err3 + } else if stop { + return nil } return nil @@ -3214,10 +3341,14 @@ func diffDataHelper( default: } - retCh <- batchWithKind{ + if stop, err3 := emitBatch(emit, batchWithKind{ batch: bat, kind: diffInsert, name: tblStuff.baseRel.GetTableName(), + }, false, tblStuff.retPool); err3 != nil { + return err3 + } else if stop { + return nil } return nil }, -1); err != nil { @@ -3313,7 +3444,11 @@ func handleDelsOnLCA( valsBuf.WriteString(fmt.Sprintf("row(%d,", i)) for j := range tuple { - formatValIntoString(ses, tuple[j], colTypes[expandedPKColIdxes[j]], valsBuf) + if err = formatValIntoString( + ses, tuple[j], colTypes[expandedPKColIdxes[j]], valsBuf, + ); err != nil { + return nil, err + } if j != 
len(tuple)-1 { valsBuf.WriteString(", ") } @@ -3461,10 +3596,10 @@ func handleDelsOnLCA( return } -func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) { +func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) error { if val == nil { buf.WriteString("NULL") - return + return nil } var scratch [64]byte @@ -3481,6 +3616,10 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) buf.Write(strconv.AppendFloat(scratch[:0], v, 'g', -1, bitSize)) } + writeBool := func(v bool) { + buf.WriteString(strconv.FormatBool(v)) + } + switch t.Oid { case types.T_varchar, types.T_text, types.T_json, types.T_char, types. T_varbinary, types.T_binary: @@ -3491,7 +3630,7 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) strVal = x.String() case *bytejson.ByteJson: if x == nil { - panic(moerr.NewInternalErrorNoCtx("formatValIntoString: nil *bytejson.ByteJson")) + return moerr.NewInternalErrorNoCtx("formatValIntoString: nil *bytejson.ByteJson") } strVal = x.String() case []byte: @@ -3499,14 +3638,14 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) case string: strVal = x default: - panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected json type %T", val)) + return moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected json type %T", val) } jsonLiteral := escapeJSONControlBytes([]byte(strVal)) if !json.Valid(jsonLiteral) { - panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: invalid json input %q", strVal)) + return moerr.NewInternalErrorNoCtxf("formatValIntoString: invalid json input %q", strVal) } writeEscapedSQLString(buf, jsonLiteral) - return + return nil } switch x := val.(type) { case []byte: @@ -3514,7 +3653,7 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) case string: writeEscapedSQLString(buf, []byte(x)) default: - panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected string type %T", val)) + return moerr.NewInternalErrorNoCtxf("formatValIntoString: unexpected string type %T", val) } case types.T_timestamp: buf.WriteString("'") @@ -3524,6 +3663,10 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) buf.WriteString("'") buf.WriteString(val.(types.Datetime).String2(t.Scale)) buf.WriteString("'") + case types.T_time: + buf.WriteString("'") + buf.WriteString(val.(types.Time).String2(t.Scale)) + buf.WriteString("'") case types.T_date: buf.WriteString("'") buf.WriteString(val.(types.Date).String()) @@ -3534,6 +3677,8 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) buf.WriteString(val.(types.Decimal128).Format(t.Scale)) case types.T_decimal256: buf.WriteString(val.(types.Decimal256).Format(t.Scale)) + case types.T_bool: + writeBool(val.(bool)) case types.T_uint8: writeUint(uint64(val.(uint8))) case types.T_uint16: @@ -3563,8 +3708,10 @@ func formatValIntoString(ses *Session, val any, t types.Type, buf *bytes.Buffer) buf.WriteString(types.ArrayToString[float64](val.([]float64))) buf.WriteString("'") default: - panic(moerr.NewInternalErrorNoCtxf("formatValIntoString: unsupported type %v", t.Oid)) + return moerr.NewNotSupportedNoCtxf("formatValIntoString: not support type %v", t.Oid) } + + return nil } // writeEscapedSQLString escapes special and control characters for SQL literal output. @@ -4389,114 +4536,140 @@ func compareSingleValInVector( // Use raw values to avoid format conversions in extractRowFromVector. 
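// NOTE (editor, illustrative only, not part of this change): every fixed-size case in the
// switch below now follows the same return-directly shape. If further deduplication were
// wanted, a small generic helper could express it; the exact type-parameter constraint of
// vector.GetFixedAtNoTypeCheck is elided here and assumed compatible:
//
//	func compareFixed[T any](vec1, vec2 *vector.Vector, i, j int, cmp func(a, b T) int) int {
//		return cmp(
//			vector.GetFixedAtNoTypeCheck[T](vec1, i),
//			vector.GetFixedAtNoTypeCheck[T](vec2, j),
//		)
//	}
//
// The switch is kept per-type here because each branch still has to spell out its concrete
// element type.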
switch vec1.GetType().Oid { case types.T_json: - val1 := types.DecodeJson(vec1.GetBytesAt(rowIdx1)) - val2 := types.DecodeJson(vec2.GetBytesAt(rowIdx2)) - return bytejson.CompareByteJson(val1, val2), nil + return bytejson.CompareByteJson( + types.DecodeJson(vec1.GetBytesAt(rowIdx1)), + types.DecodeJson(vec2.GetBytesAt(rowIdx2)), + ), nil case types.T_bool: - val1 := vector.GetFixedAtNoTypeCheck[bool](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[bool](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[bool](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[bool](vec2, rowIdx2), + ), nil case types.T_bit: - val1 := vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2), + ), nil case types.T_int8: - val1 := vector.GetFixedAtNoTypeCheck[int8](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[int8](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[int8](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[int8](vec2, rowIdx2), + ), nil case types.T_uint8: - val1 := vector.GetFixedAtNoTypeCheck[uint8](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[uint8](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[uint8](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[uint8](vec2, rowIdx2), + ), nil case types.T_int16: - val1 := vector.GetFixedAtNoTypeCheck[int16](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[int16](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[int16](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[int16](vec2, rowIdx2), + ), nil case types.T_uint16: - val1 := vector.GetFixedAtNoTypeCheck[uint16](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[uint16](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[uint16](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[uint16](vec2, rowIdx2), + ), nil case types.T_int32: - val1 := vector.GetFixedAtNoTypeCheck[int32](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[int32](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[int32](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[int32](vec2, rowIdx2), + ), nil case types.T_uint32: - val1 := vector.GetFixedAtNoTypeCheck[uint32](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[uint32](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[uint32](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[uint32](vec2, rowIdx2), + ), nil case types.T_int64: - val1 := vector.GetFixedAtNoTypeCheck[int64](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[int64](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[int64](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[int64](vec2, rowIdx2), + ), nil case types.T_uint64: - val1 := vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return 
types.CompareValue( + vector.GetFixedAtNoTypeCheck[uint64](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[uint64](vec2, rowIdx2), + ), nil case types.T_float32: - val1 := vector.GetFixedAtNoTypeCheck[float32](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[float32](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[float32](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[float32](vec2, rowIdx2), + ), nil case types.T_float64: - val1 := vector.GetFixedAtNoTypeCheck[float64](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[float64](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[float64](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[float64](vec2, rowIdx2), + ), nil case types.T_char, types.T_varchar, types.T_blob, types.T_text, types.T_binary, types.T_varbinary, types.T_datalink: return bytes.Compare( vec1.GetBytesAt(rowIdx1), vec2.GetBytesAt(rowIdx2), ), nil case types.T_array_float32: - val1 := vector.GetArrayAt[float32](vec1, rowIdx1) - val2 := vector.GetArrayAt[float32](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetArrayAt[float32](vec1, rowIdx1), + vector.GetArrayAt[float32](vec2, rowIdx2), + ), nil case types.T_array_float64: - val1 := vector.GetArrayAt[float64](vec1, rowIdx1) - val2 := vector.GetArrayAt[float64](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetArrayAt[float64](vec1, rowIdx1), + vector.GetArrayAt[float64](vec2, rowIdx2), + ), nil case types.T_date: - val1 := vector.GetFixedAtNoTypeCheck[types.Date](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Date](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Date](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Date](vec2, rowIdx2), + ), nil case types.T_datetime: - val1 := vector.GetFixedAtNoTypeCheck[types.Datetime](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Datetime](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Datetime](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Datetime](vec2, rowIdx2), + ), nil case types.T_time: - val1 := vector.GetFixedAtNoTypeCheck[types.Time](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Time](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Time](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Time](vec2, rowIdx2), + ), nil case types.T_timestamp: - val1 := vector.GetFixedAtNoTypeCheck[types.Timestamp](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Timestamp](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Timestamp](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Timestamp](vec2, rowIdx2), + ), nil case types.T_decimal64: - val1 := vector.GetFixedAtNoTypeCheck[types.Decimal64](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Decimal64](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Decimal64](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Decimal64](vec2, rowIdx2), + ), nil case types.T_decimal128: - val1 := vector.GetFixedAtNoTypeCheck[types.Decimal128](vec1, 
rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Decimal128](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Decimal128](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Decimal128](vec2, rowIdx2), + ), nil case types.T_uuid: - val1 := vector.GetFixedAtNoTypeCheck[types.Uuid](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Uuid](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Uuid](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Uuid](vec2, rowIdx2), + ), nil case types.T_Rowid: - val1 := vector.GetFixedAtNoTypeCheck[types.Rowid](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Rowid](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Rowid](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Rowid](vec2, rowIdx2), + ), nil case types.T_Blockid: - val1 := vector.GetFixedAtNoTypeCheck[types.Blockid](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Blockid](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Blockid](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Blockid](vec2, rowIdx2), + ), nil case types.T_TS: - val1 := vector.GetFixedAtNoTypeCheck[types.TS](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.TS](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.TS](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.TS](vec2, rowIdx2), + ), nil case types.T_enum: - val1 := vector.GetFixedAtNoTypeCheck[types.Enum](vec1, rowIdx1) - val2 := vector.GetFixedAtNoTypeCheck[types.Enum](vec2, rowIdx2) - return types.CompareValue(val1, val2), nil + return types.CompareValue( + vector.GetFixedAtNoTypeCheck[types.Enum](vec1, rowIdx1), + vector.GetFixedAtNoTypeCheck[types.Enum](vec2, rowIdx2), + ), nil default: return 0, moerr.NewInternalErrorNoCtxf("compareSingleValInVector : unsupported type %d", vec1.GetType().Oid) } diff --git a/pkg/frontend/data_branch_test.go b/pkg/frontend/data_branch_test.go index 0d9aba8339b0f..c6be503e4c9af 100644 --- a/pkg/frontend/data_branch_test.go +++ b/pkg/frontend/data_branch_test.go @@ -40,7 +40,7 @@ func TestFormatValIntoString_StringEscaping(t *testing.T) { ses := &Session{} val := "a'b\"c\\\n\t\r\x1a\x00" - formatValIntoString(ses, val, types.New(types.T_varchar, 0, 0), &buf) + require.NoError(t, formatValIntoString(ses, val, types.New(types.T_varchar, 0, 0), &buf)) require.Equal(t, `'a\'b"c\\\n\t\r\Z\0'`, buf.String()) } @@ -49,16 +49,27 @@ func TestFormatValIntoString_ByteEscaping(t *testing.T) { ses := &Session{} val := []byte{'x', 0x00, '\\', 0x07, '\''} - formatValIntoString(ses, val, types.New(types.T_varbinary, 0, 0), &buf) + require.NoError(t, formatValIntoString(ses, val, types.New(types.T_varbinary, 0, 0), &buf)) require.Equal(t, `'x\0\\\x07\''`, buf.String()) } +func TestFormatValIntoString_Time(t *testing.T) { + var buf bytes.Buffer + ses := &Session{} + + val, err := types.ParseTime("12:34:56.123456", 6) + require.NoError(t, err) + + require.NoError(t, formatValIntoString(ses, val, types.New(types.T_time, 0, 6), &buf)) + require.Equal(t, `'12:34:56.123456'`, buf.String()) +} + func TestFormatValIntoString_JSONEscaping(t *testing.T) { var buf bytes.Buffer ses := &Session{} val := `{"k":"` + string([]byte{0x01, 
'\n'}) + `"}` - formatValIntoString(ses, val, types.New(types.T_json, 0, 0), &buf) + require.NoError(t, formatValIntoString(ses, val, types.New(types.T_json, 0, 0), &buf)) require.Equal(t, `'{"k":"\\u0001\\u000a"}'`, buf.String()) } @@ -69,7 +80,7 @@ func TestFormatValIntoString_JSONByteJson(t *testing.T) { bj, err := types.ParseStringToByteJson(`{"a":1}`) require.NoError(t, err) - formatValIntoString(ses, bj, types.New(types.T_json, 0, 0), &buf) + require.NoError(t, formatValIntoString(ses, bj, types.New(types.T_json, 0, 0), &buf)) require.Equal(t, `'{"a": 1}'`, buf.String()) } @@ -77,7 +88,7 @@ func TestFormatValIntoString_Nil(t *testing.T) { var buf bytes.Buffer ses := &Session{} - formatValIntoString(ses, nil, types.New(types.T_varchar, 0, 0), &buf) + require.NoError(t, formatValIntoString(ses, nil, types.New(types.T_varchar, 0, 0), &buf)) require.Equal(t, "NULL", buf.String()) } @@ -85,9 +96,9 @@ func TestFormatValIntoString_UnsupportedType(t *testing.T) { var buf bytes.Buffer ses := &Session{} - require.Panics(t, func() { - formatValIntoString(ses, true, types.New(types.T_bool, 0, 0), &buf) - }) + err := formatValIntoString(ses, true, types.New(types.T_Rowid, 0, 0), &buf) + require.Error(t, err) + require.Contains(t, err.Error(), "not support type") } func TestCompareSingleValInVector_AllTypes(t *testing.T) { diff --git a/pkg/frontend/export.go b/pkg/frontend/export.go index 1a62b65cca1cf..e6a0316a23551 100644 --- a/pkg/frontend/export.go +++ b/pkg/frontend/export.go @@ -364,6 +364,21 @@ func constructByte(ctx context.Context, obj FeSession, bat *batch.Batch, index i mp = ss.GetMemPool() } + // respect cancellation to avoid blocking when downstream writer stops + if ctx.Err() != nil { + bat.Clean(mp) + return + } + + sendByte := func(bb *BatchByte) bool { + select { + case ByteChan <- bb: + return true + case <-ctx.Done(): + return false + } + } + symbol := ep.Symbol closeby := ep.userConfig.Fields.EnclosedBy.Value terminated := ep.userConfig.Fields.Terminated.Value @@ -501,9 +516,10 @@ func constructByte(ctx context.Context, obj FeSession, bat *batch.Batch, index i zap.Int("typeOid", int(vec.GetType().Oid))) } - ByteChan <- &BatchByte{ + // stop early if downstream already failed + sendByte(&BatchByte{ err: moerr.NewInternalErrorf(ctx, "constructByte : unsupported type %d", vec.GetType().Oid), - } + }) bat.Clean(mp) return } @@ -516,10 +532,13 @@ func constructByte(ctx context.Context, obj FeSession, bat *batch.Batch, index i copy(result, buffer.Bytes()) buffer = nil - ByteChan <- &BatchByte{ + if !sendByte(&BatchByte{ index: index, writeByte: result, err: nil, + }) { + bat.Clean(mp) + return } if ss != nil { diff --git a/pkg/tests/dml/dml_test.go b/pkg/tests/dml/dml_test.go index ecb29a789837a..49376bad692cf 100644 --- a/pkg/tests/dml/dml_test.go +++ b/pkg/tests/dml/dml_test.go @@ -152,6 +152,15 @@ func TestDataBranchDiffAsFile(t *testing.T) { t.Log("csv diff covers rich data type payloads") runCSVLoadRichTypes(t, ctx, sqlDB) + t.Log("diff output limit returns subset of full diff") + runDiffOutputLimitSubset(t, ctx, sqlDB) + + t.Log("diff output limit without branch relationship returns subset of full diff") + runDiffOutputLimitNoBase(t, ctx, sqlDB) + + t.Log("diff output limit with large base workload still returns subset of full diff") + runDiffOutputLimitLargeBase(t, ctx, sqlDB) + t.Log("diff output to stage and load via datalink") runDiffOutputToStage(t, ctx, sqlDB) }) @@ -532,6 +541,149 @@ insert into %s values assertTablesEqual(t, ctx, db, dbName, target, base) } +func 
runDiffOutputLimitSubset(t *testing.T, parentCtx context.Context, db *sql.DB) { + t.Helper() + + ctx, cancel := context.WithTimeout(parentCtx, time.Second*90) + defer cancel() + + dbName := testutils.GetDatabaseName(t) + base := "limit_base" + branch := "limit_branch" + + execSQLDB(t, ctx, db, fmt.Sprintf("create database `%s`", dbName)) + defer func() { + execSQLDB(t, ctx, db, "use mo_catalog") + execSQLDB(t, ctx, db, fmt.Sprintf("drop database if exists `%s`", dbName)) + }() + execSQLDB(t, ctx, db, fmt.Sprintf("use `%s`", dbName)) + + execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (id int primary key, val int, note varchar(16))", base)) + execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (1, 10, 'seed'), (2, 20, 'seed'), (3, 30, 'seed'), (4, 40, 'seed'), (5, 50, 'seed'), (6, 60, 'seed')", base)) + + execSQLDB(t, ctx, db, fmt.Sprintf("data branch create table %s from %s", branch, base)) + execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (7, 70, 'inserted'), (8, 80, 'inserted')", branch)) + execSQLDB(t, ctx, db, fmt.Sprintf("update %s set val = val + 100, note = 'updated' where id in (2,3)", branch)) + execSQLDB(t, ctx, db, fmt.Sprintf("delete from %s where id in (4,5)", branch)) + + fullStmt := fmt.Sprintf("data branch diff %s against %s", branch, base) + fullRows := fetchDiffRowsAsStrings(t, ctx, db, fullStmt) + require.GreaterOrEqual(t, len(fullRows), 6) + + limit := 3 + limitStmt := fmt.Sprintf("data branch diff %s against %s output limit %d", branch, base, limit) + limitedRows := fetchDiffRowsAsStrings(t, ctx, db, limitStmt) + + require.NotEmpty(t, limitedRows, "limited diff returned no rows") + require.LessOrEqual(t, len(limitedRows), limit, "limited diff returned too many rows") + + fullSet := make(map[string]struct{}, len(fullRows)) + for _, row := range fullRows { + fullSet[strings.Join(row, "||")] = struct{}{} + } + for _, row := range limitedRows { + _, ok := fullSet[strings.Join(row, "||")] + require.Truef(t, ok, "limited diff row not contained in full diff: %v", row) + } +} + +func runDiffOutputLimitNoBase(t *testing.T, parentCtx context.Context, db *sql.DB) { + t.Helper() + + ctx, cancel := context.WithTimeout(parentCtx, time.Second*90) + defer cancel() + + dbName := testutils.GetDatabaseName(t) + base := "limit_nobranch_base" + target := "limit_nobranch_target" + + execSQLDB(t, ctx, db, fmt.Sprintf("create database `%s`", dbName)) + defer func() { + execSQLDB(t, ctx, db, "use mo_catalog") + execSQLDB(t, ctx, db, fmt.Sprintf("drop database if exists `%s`", dbName)) + }() + execSQLDB(t, ctx, db, fmt.Sprintf("use `%s`", dbName)) + + execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (id int primary key, val int, note varchar(16))", base)) + execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (id int primary key, val int, note varchar(16))", target)) + + execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (1, 10, 'seed'), (2, 20, 'seed'), (3, 30, 'seed'), (4, 40, 'seed')", base)) + + execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s values (1, 110, 'updated'), (2, 20, 'seed'), (5, 500, 'added'), (6, 600, 'added')", target)) + + fullStmt := fmt.Sprintf("data branch diff %s against %s", target, base) + fullRows := fetchDiffRowsAsStrings(t, ctx, db, fullStmt) + require.GreaterOrEqual(t, len(fullRows), 3) + + limit := 1 + limitStmt := fmt.Sprintf("data branch diff %s against %s output limit %d", target, base, limit) + limitedRows := fetchDiffRowsAsStrings(t, ctx, db, limitStmt) + + require.NotEmpty(t, limitedRows, "limited diff returned no rows") + 
require.LessOrEqual(t, len(limitedRows), limit, "limited diff returned too many rows") + + fullSet := make(map[string]struct{}, len(fullRows)) + for _, row := range fullRows { + fullSet[strings.Join(row, "||")] = struct{}{} + } + for _, row := range limitedRows { + _, ok := fullSet[strings.Join(row, "||")] + require.Truef(t, ok, "limited diff row not contained in full diff: %v", row) + } +} + +func runDiffOutputLimitLargeBase(t *testing.T, parentCtx context.Context, db *sql.DB) { + t.Helper() + + ctx, cancel := context.WithTimeout(parentCtx, time.Second*180) + defer cancel() + + dbName := testutils.GetDatabaseName(t) + base := "limit_large_t1" + branch := "limit_large_t2" + + execSQLDB(t, ctx, db, fmt.Sprintf("create database `%s`", dbName)) + defer func() { + execSQLDB(t, ctx, db, "use mo_catalog") + execSQLDB(t, ctx, db, fmt.Sprintf("drop database if exists `%s`", dbName)) + }() + execSQLDB(t, ctx, db, fmt.Sprintf("use `%s`", dbName)) + + execSQLDB(t, ctx, db, fmt.Sprintf("create table %s (a int primary key, b int, c time)", base)) + execSQLDB(t, ctx, db, fmt.Sprintf("insert into %s select *, *, '12:34:56' from generate_series(1, 8192*100)g", base)) + + execSQLDB(t, ctx, db, fmt.Sprintf("data branch create table %s from %s", branch, base)) + + execSQLDB(t, ctx, db, fmt.Sprintf("update %s set b = b + 1 where a between 1 and 10000", base)) + execSQLDB(t, ctx, db, fmt.Sprintf("update %s set b = b + 2 where a between 10000 and 10001", branch)) + execSQLDB(t, ctx, db, fmt.Sprintf("delete from %s where a between 30000 and 100000", base)) + + fullStmt := fmt.Sprintf("data branch diff %s against %s", branch, base) + fullRows := fetchDiffRowsAsStrings(t, ctx, db, fullStmt) + //require.Equal(t, 30, len(fullRows), fmt.Sprintf("full diff: %v", fullRows)) + fmt.Println("full diff:", len(fullRows)) + + limitQuery := func(cnt int) { + limitStmt := fmt.Sprintf("data branch diff %s against %s output limit %d", branch, base, cnt) + limitedRows := fetchDiffRowsAsStrings(t, ctx, db, limitStmt) + + require.NotEmpty(t, limitedRows, "limited diff returned no rows") + require.LessOrEqual(t, len(limitedRows), cnt, fmt.Sprintf("limited diff returned too many rows: %v", limitedRows)) + + fullSet := make(map[string]struct{}, len(fullRows)) + for _, row := range fullRows { + fullSet[strings.Join(row, "||")] = struct{}{} + } + for _, row := range limitedRows { + _, ok := fullSet[strings.Join(row, "||")] + require.Truef(t, ok, "limited diff row not contained in full diff: %v", row) + } + } + + limitQuery(len(fullRows) * 1 / 100) + limitQuery(len(fullRows) * 20 / 100) +} + func runDiffOutputToStage(t *testing.T, parentCtx context.Context, db *sql.DB) { t.Helper() @@ -678,6 +830,40 @@ func parseSQLStatements(content string) []string { return stmts } +func fetchDiffRowsAsStrings(t *testing.T, ctx context.Context, db *sql.DB, stmt string) [][]string { + t.Helper() + + rows, err := db.QueryContext(ctx, stmt) + require.NoErrorf(t, err, "sql: %s", stmt) + defer rows.Close() + + cols, err := rows.Columns() + require.NoError(t, err) + + result := make([][]string, 0, 8) + for rows.Next() { + raw := make([]sql.RawBytes, len(cols)) + dest := make([]any, len(cols)) + for i := range raw { + dest[i] = &raw[i] + } + require.NoError(t, rows.Scan(dest...)) + + row := make([]string, len(cols)) + for i, b := range raw { + if b == nil { + row[i] = "NULL" + continue + } + row[i] = string(b) + } + result = append(result, row) + } + require.NoErrorf(t, rows.Err(), "sql: %s", stmt) + require.NotEmpty(t, result, "diff statement returned no 
rows: %s", stmt) + return result +} + func execDiffAndFetchFile(t *testing.T, ctx context.Context, db *sql.DB, stmt string) string { t.Helper()