diff --git a/feature/s3/transfermanager/api.go b/feature/s3/transfermanager/api.go index 526cbf96ff0..0255135accf 100644 --- a/feature/s3/transfermanager/api.go +++ b/feature/s3/transfermanager/api.go @@ -15,4 +15,5 @@ type S3APIClient interface { AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error) HeadObject(context.Context, *s3.HeadObjectInput, ...func(*s3.Options)) (*s3.HeadObjectOutput, error) + ListObjectsV2(context.Context, *s3.ListObjectsV2Input, ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) } diff --git a/feature/s3/transfermanager/api_op_DownloadDirectory.go b/feature/s3/transfermanager/api_op_DownloadDirectory.go new file mode 100644 index 00000000000..d1b31b5de04 --- /dev/null +++ b/feature/s3/transfermanager/api_op_DownloadDirectory.go @@ -0,0 +1,252 @@ +package transfermanager + +import ( + "context" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +// DownloadDirectoryInput represents a request to the DownloadDirectory() call +type DownloadDirectoryInput struct { + // Bucket where objects are downloaded from + Bucket string + + // The destination directory to download + Destination string + + // The S3 key prefix to use for listing objects. 
If not provided,
+	// all objects under a bucket will be retrieved
+	KeyPrefix string
+
+	// The s3 delimiter used to convert keyname to local filepath if it
+	// is different from local file separator
+	S3Delimiter string
+
+	// A callback func to allow users to filter out unwanted objects
+	// according to the bool returned from the function
+	Filter ObjectFilter
+
+	// A callback function to allow customers to update individual
+	// GetObjectInput that the S3 Transfer Manager generates
+	Callback GetRequestCallback
+}
+
+// ObjectFilter is the callback to allow users to filter out unwanted objects.
+// It is invoked for each object listed.
+type ObjectFilter interface {
+	// FilterObject takes the Object struct and decides if the
+	// object should be downloaded
+	FilterObject(s3types.Object) bool
+}
+
+// GetRequestCallback is the callback mechanism to allow customers to update
+// individual GetObjectInput that the S3 Transfer Manager generates
+type GetRequestCallback interface {
+	// UpdateRequest preprocesses each GetObjectInput as customized
+	UpdateRequest(*GetObjectInput)
+}
+
+// DownloadDirectoryOutput represents a response from the DownloadDirectory() call
+type DownloadDirectoryOutput struct {
+	// Total number of objects successfully downloaded
+	ObjectsDownloaded int
+}
+
+type objectEntry struct {
+	key  string
+	path string
+}
+
+// DownloadDirectory traverses an S3 bucket and intelligently downloads all valid objects
+// to local directory in parallel across multiple goroutines. You can configure the concurrency,
+// valid object filtering and hierarchical file naming through the Options and input parameters.
+//
+// Additional functional options can be provided to configure the individual directory
+// download. These options are copies of the original Options instance, the client of which DownloadDirectory is called from.
+// Modifying the options will not impact the original Client and Options instance.
+func (c *Client) DownloadDirectory(ctx context.Context, input *DownloadDirectoryInput, opts ...func(*Options)) (*DownloadDirectoryOutput, error) { + fileInfo, err := os.Stat(input.Destination) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("error when getting destination folder info: %v", err) + } + } else if !fileInfo.IsDir() { + return nil, fmt.Errorf("the destination path %s doesn't point to a valid directory", input.Destination) + + } + + i := directoryDownloader{c: c, in: input, options: c.options.Copy()} + for _, opt := range opts { + opt(&i.options) + } + + return i.downloadDirectory(ctx) +} + +type directoryDownloader struct { + c *Client + options Options + in *DownloadDirectoryInput + + objectsDownloaded int + + err error + + mu sync.Mutex + wg sync.WaitGroup +} + +func (d *directoryDownloader) downloadDirectory(ctx context.Context) (*DownloadDirectoryOutput, error) { + d.init() + ch := make(chan objectEntry) + + for i := 0; i < d.options.DirectoryConcurrency; i++ { + d.wg.Add(1) + go d.downloadObject(ctx, ch) + } + + isTruncated := true + continuationToken := "" + for isTruncated { + if d.getErr() != nil { + break + } + listOutput, err := d.options.S3.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(d.in.Bucket), + Prefix: nzstring(d.in.KeyPrefix), + ContinuationToken: nzstring(continuationToken), + }) + if err != nil { + d.setErr(fmt.Errorf("error when listing objects %v", err)) + break + } + + for _, o := range listOutput.Contents { + key := aws.ToString(o.Key) + if strings.HasSuffix(key, "/") || strings.HasSuffix(key, d.in.S3Delimiter) { + continue // skip folder object + } + if d.in.Filter != nil && !d.in.Filter.FilterObject(o) { + continue + } + path, err := d.getLocalPath(key) + if err != nil { + d.setErr(fmt.Errorf("error when resolving local path for object %s, %v", key, err)) + break + } + ch <- objectEntry{key, path} + } + + continuationToken = aws.ToString(listOutput.NextContinuationToken) + 
isTruncated = aws.ToBool(listOutput.IsTruncated) + } + + close(ch) + d.wg.Wait() + + if d.err != nil { + return nil, d.err + } + + return &DownloadDirectoryOutput{ + ObjectsDownloaded: d.objectsDownloaded, + }, nil +} + +func (d *directoryDownloader) init() { + if d.in.S3Delimiter == "" { + d.in.S3Delimiter = "/" + } +} + +func (d *directoryDownloader) getLocalPath(key string) (string, error) { + keyprefix := d.in.KeyPrefix + if keyprefix != "" && !strings.HasSuffix(keyprefix, d.in.S3Delimiter) { + keyprefix = keyprefix + d.in.S3Delimiter + } + path := filepath.Join(d.in.Destination, strings.ReplaceAll(strings.TrimPrefix(key, keyprefix), d.in.S3Delimiter, string(os.PathSeparator))) + relPath, err := filepath.Rel(d.in.Destination, path) + if err != nil { + return "", err + } + if relPath == "." || strings.Contains(relPath, "..") { + return "", fmt.Errorf("resolved local path %s is outside of destination %s", path, d.in.Destination) + } + + return path, nil +} + +func (d *directoryDownloader) downloadObject(ctx context.Context, ch chan objectEntry) { + defer d.wg.Done() + for { + data, ok := <-ch + if !ok { + break + } + if d.getErr() != nil { + break + } + + input := &GetObjectInput{ + Bucket: d.in.Bucket, + Key: data.key, + } + if d.in.Callback != nil { + d.in.Callback.UpdateRequest(input) + } + out, err := d.c.GetObject(ctx, input) + if err != nil { + d.setErr(fmt.Errorf("error when downloading object %s: %v", data.key, err)) + break + } + + err = os.MkdirAll(filepath.Dir(data.path), os.ModePerm) + if err != nil { + d.setErr(fmt.Errorf("error when creating directory for file %s: %v", data.path, err)) + break + } + file, err := os.Create(data.path) + if err != nil { + d.setErr(fmt.Errorf("error when creating file %s: %v", data.path, err)) + break + } + _, err = io.Copy(file, out.Body) + if err != nil { + d.setErr(fmt.Errorf("error when writing to local file %s: %v", data.path, err)) + os.Remove(data.path) + break + } + d.incrObjectsDownloaded(1) + } +} + +func (d 
*directoryDownloader) incrObjectsDownloaded(n int) { + d.mu.Lock() + defer d.mu.Unlock() + + d.objectsDownloaded += n +} + +func (d *directoryDownloader) setErr(err error) { + d.mu.Lock() + defer d.mu.Unlock() + + d.err = err +} + +func (d *directoryDownloader) getErr() error { + d.mu.Lock() + defer d.mu.Unlock() + + return d.err +} diff --git a/feature/s3/transfermanager/api_op_DownloadDirectory_integ_test.go b/feature/s3/transfermanager/api_op_DownloadDirectory_integ_test.go new file mode 100644 index 00000000000..1bfc71b29de --- /dev/null +++ b/feature/s3/transfermanager/api_op_DownloadDirectory_integ_test.go @@ -0,0 +1,44 @@ +//go:build integration +// +build integration + +package transfermanager + +import ( + "testing" +) + +func TestInteg_DownloadDirectory(t *testing.T) { + cases := map[string]downloadDirectoryTestData{ + "multi objects with prefix": { + ObjectsSize: map[string]int64{ + "oii/bar": 2 * 1024 * 1024, + "oiibaz/zoo": 10 * 1024 * 1024, + "oii/baz/zoo": 10 * 1024 * 1024, + "oi": 20 * 1024 * 1024, + }, + KeyPrefix: "oii", + ExpectObjectsDownloaded: 3, + ExpectFiles: []string{"bar", "oiibaz/zoo", "baz/zoo"}, + }, + "multi file with prefix and custom delimiter": { + ObjectsSize: map[string]int64{ + "yee#bar": 2 * 1024 * 1024, + "yee#baz#": 0, + "yee#baz#zoo": 10 * 1024 * 1024, + "yee#oii@zoo": 10 * 1024 * 1024, + "yee#yee#..#bla": 2 * 1024 * 1024, + "ye": 20 * 1024 * 1024, + }, + KeyPrefix: "yee#", + Delimiter: "#", + ExpectObjectsDownloaded: 4, + ExpectFiles: []string{"bar", "baz/zoo", "oii@zoo", "bla"}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + testDownloadDirectory(t, setupMetadata.Buckets.Source.Name, c) + }) + } +} diff --git a/feature/s3/transfermanager/api_op_UploadDirectory.go b/feature/s3/transfermanager/api_op_UploadDirectory.go index e4513ac67ea..ae95c9a3c5a 100644 --- a/feature/s3/transfermanager/api_op_UploadDirectory.go +++ b/feature/s3/transfermanager/api_op_UploadDirectory.go @@ -315,22 +315,22 @@ 
func (u *directoryUploader) uploadFile(ctx context.Context, ch chan fileEntry) { f, err := os.Open(data.path) if err != nil { u.setErr(fmt.Errorf("error when opening file %s: %v", data.path, err)) - } else { - input := &PutObjectInput{ - Bucket: u.in.Bucket, - Key: data.key, - Body: f, - } - if u.in.Callback != nil { - u.in.Callback.UpdateRequest(input) - } - _, err := u.c.PutObject(ctx, input) - if err != nil { - u.setErr(fmt.Errorf("error when uploading file %s: %v", data.path, err)) - } else { - u.incrFilesUploaded(1) - } + break + } + input := &PutObjectInput{ + Bucket: u.in.Bucket, + Key: data.key, + Body: f, + } + if u.in.Callback != nil { + u.in.Callback.UpdateRequest(input) + } + _, err = u.c.PutObject(ctx, input) + if err != nil { + u.setErr(fmt.Errorf("error when uploading file %s: %v", data.path, err)) + break } + u.incrFilesUploaded(1) } } diff --git a/feature/s3/transfermanager/api_op_UploadDirectory_integ_test.go b/feature/s3/transfermanager/api_op_UploadDirectory_integ_test.go index e32fad79df1..cff128ec93b 100644 --- a/feature/s3/transfermanager/api_op_UploadDirectory_integ_test.go +++ b/feature/s3/transfermanager/api_op_UploadDirectory_integ_test.go @@ -40,7 +40,7 @@ func TestInteg_UploadDirectory(t *testing.T) { ExpectFilesUploaded: 3, ExpectKeys: []string{"bla/foo", "bla/to/bar", "bla/to/the/baz"}, }, - "multi file recursive with prefix and custome delimiter": { + "multi file recursive with prefix and custom delimiter": { FilesSize: map[string]int64{ "foo": 2 * 1024 * 1024, "to/bar": 10 * 1024 * 1024, diff --git a/feature/s3/transfermanager/download_directory_test.go b/feature/s3/transfermanager/download_directory_test.go new file mode 100644 index 00000000000..9f49fe3a908 --- /dev/null +++ b/feature/s3/transfermanager/download_directory_test.go @@ -0,0 +1,511 @@ +package transfermanager + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "runtime" + "sort" + "strings" + "testing" + + 
"github.com/aws/aws-sdk-go-v2/aws" + s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager/internal/testing" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +type objectkeyFilter struct { + keyword string +} + +func (of *objectkeyFilter) FilterObject(object s3types.Object) bool { + if strings.Contains(aws.ToString(object.Key), of.keyword) { + return false + } + return true +} + +type objectkeyCallback struct { + keyword string +} + +func (oc *objectkeyCallback) UpdateRequest(in *GetObjectInput) { + if in.Key == oc.keyword { + in.Key = in.Key + "gotyou" + } +} + +func TestDownloadDirectory(t *testing.T) { + _, filename, _, _ := runtime.Caller(0) + root := filepath.Join(filepath.Dir(filename), "testdata") + + cases := map[string]struct { + destination string + keyPrefix string + objectsLists [][]s3types.Object + continuationTokens []string + filter ObjectFilter + s3Delimiter string + callback GetRequestCallback + getobjectFn func(*s3testing.TransferManagerLoggingClient, *s3.GetObjectInput) (*s3.GetObjectOutput, error) + expectTokens []string + expectKeys []string + expectFiles []string + expectErr string + expectObjectsDownloaded int + }{ + "single object": { + destination: "single-object", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("foo/bar"), + }, + }, + }, + expectTokens: []string{""}, + expectKeys: []string{"foo/bar"}, + expectFiles: []string{"foo/bar"}, + expectObjectsDownloaded: 1, + }, + "multiple objects": { + destination: "multiple-objects", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("foo/bar"), + }, + { + Key: aws.String("baz"), + }, + { + Key: aws.String("foo/zoo/bar"), + }, + { + Key: aws.String("foo/zoo/oii/bababoii"), + }, + }, + }, + expectTokens: []string{""}, + expectKeys: []string{"foo/bar", "baz", "foo/zoo/bar", "foo/zoo/oii/bababoii"}, + expectFiles: []string{"foo/bar", "baz", "foo/zoo/bar", "foo/zoo/oii/bababoii"}, + 
expectObjectsDownloaded: 4, + }, + "multiple objects paginated": { + destination: "multiple-objects-paginated", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("foo/bar"), + }, + { + Key: aws.String("baz"), + }, + }, + { + { + Key: aws.String("foo/zoo/bar"), + }, + { + Key: aws.String("foo/zoo/oii/bababoii"), + }, + }, + { + { + Key: aws.String("foo/zoo/baz"), + }, + { + Key: aws.String("foo/zoo/oii/yee"), + }, + }, + }, + continuationTokens: []string{"token1", "token2"}, + expectTokens: []string{"", "token1", "token2"}, + expectKeys: []string{"foo/bar", "baz", "foo/zoo/bar", "foo/zoo/oii/bababoii", "foo/zoo/baz", "foo/zoo/oii/yee"}, + expectObjectsDownloaded: 6, + }, + "multiple objects containing folder object": { + destination: "multiple-objects-with-folder-object", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("foo/bar"), + }, + { + Key: aws.String("baz"), + }, + { + Key: aws.String("foo/zoo/"), + }, + }, + }, + expectTokens: []string{""}, + expectKeys: []string{"foo/bar", "baz"}, + expectFiles: []string{"foo/bar", "baz"}, + expectObjectsDownloaded: 2, + }, + "single object named with keyprefix": { + destination: "single-object-named-with-keyprefix", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a"), + }, + }, + }, + keyPrefix: "a", + expectTokens: []string{""}, + expectKeys: []string{"a"}, + expectFiles: []string{"a"}, + expectObjectsDownloaded: 1, + }, + "multiple objects with keyprefix without delimiter suffix": { + destination: "multiple-objects-with-keyprefix-no-delimiter", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a/"), + }, + { + Key: aws.String("a/b"), + }, + { + Key: aws.String("ad"), + }, + { + Key: aws.String("ab/c"), + }, + { + Key: aws.String("ae"), + }, + }, + }, + keyPrefix: "a", + expectTokens: []string{""}, + expectKeys: []string{"a/b", "ad", "ab/c", "ae"}, + expectFiles: []string{"b", "ad", "ab/c", "ae"}, + expectObjectsDownloaded: 4, + }, + "multiple objects with keyprefix 
with default delimiter suffix": { + destination: "multiple-objects-with-keyprefix-default-delimiter", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a/"), + }, + { + Key: aws.String("a/b"), + }, + { + Key: aws.String("a/c"), + }, + { + Key: aws.String("ad"), + }, + { + Key: aws.String("ab/c/d"), + }, + { + Key: aws.String("ab/c/e"), + }, + }, + }, + keyPrefix: "a/", + expectTokens: []string{""}, + expectKeys: []string{"a/b", "a/c", "ad", "ab/c/d", "ab/c/e"}, + expectFiles: []string{"b", "c", "ad", "ab/c/d", "ab/c/e"}, + expectObjectsDownloaded: 5, + }, + "multiple objects with keyprefix with customized delimiter suffix": { + destination: "multiple-objects-with-keyprefix-customized-delimiter", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("ab/c*d"), + }, + { + Key: aws.String("ab/c/e"), + }, + { + Key: aws.String("ab/c*f*g"), + }, + }, + }, + keyPrefix: "ab/c", + s3Delimiter: "*", + expectTokens: []string{""}, + expectKeys: []string{"ab/c*d", "ab/c/e", "ab/c*f*g"}, + expectFiles: []string{"d", "ab/c/e", "f*g"}, + expectObjectsDownloaded: 3, + }, + "error when path resolved from objects key out of destination scope": { + destination: "error-bucket", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a/"), + }, + { + Key: aws.String("a/b"), + }, + { + Key: aws.String("a/c"), + }, + { + Key: aws.String(filepath.Join("a", "..", "..", "d")), + }, + }, + }, + expectErr: "outside of destination", + }, + "multiple objects with filter applied": { + destination: "multiple-objects-with-filter-applied", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("foo/bar"), + }, + { + Key: aws.String("baz"), + }, + { + Key: aws.String("foo/zoo/bar"), + }, + { + Key: aws.String("foo/zoo/oii/bababoii"), + }, + }, + }, + filter: &objectkeyFilter{"bababoii"}, + expectTokens: []string{""}, + expectKeys: []string{"foo/bar", "baz", "foo/zoo/bar"}, + expectFiles: []string{"foo/bar", "baz", "foo/zoo/bar"}, + expectObjectsDownloaded: 3, + }, 
+ "multiple objects with keyprefix and filter": { + destination: "multiple-objects-with-keyprefix-and-filter", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a/"), + }, + { + Key: aws.String("a/b"), + }, + { + Key: aws.String("ad"), + }, + { + Key: aws.String("ab/c"), + }, + { + Key: aws.String("ae"), + }, + }, + }, + keyPrefix: "a", + filter: &objectkeyFilter{"e"}, + expectTokens: []string{""}, + expectKeys: []string{"a/b", "ad", "ab/c"}, + expectFiles: []string{"b", "ad", "ab/c"}, + expectObjectsDownloaded: 3, + }, + "multiple objects with keyprefix and request callback": { + destination: "multiple-objects-with-keyprefix-and-callback", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a/"), + }, + { + Key: aws.String("a/b"), + }, + { + Key: aws.String("ad"), + }, + { + Key: aws.String("ab/c"), + }, + { + Key: aws.String("ae"), + }, + }, + }, + keyPrefix: "a", + callback: &objectkeyCallback{"ad"}, + expectTokens: []string{""}, + expectKeys: []string{"a/b", "adgotyou", "ab/c", "ae"}, + expectFiles: []string{"b", "ad", "ab/c", "ae"}, + expectObjectsDownloaded: 4, + }, + "multiple objects paginated with keyprefix, delimiter, filter and callback": { + destination: "multiple-objects-with-keyprefix-delimiter-filter-callback", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("a&"), + }, + { + Key: aws.String("a&b"), + }, + { + Key: aws.String("a@b"), + }, + }, + { + { + Key: aws.String("a&foo&bar"), + }, + { + Key: aws.String("ac"), + }, + { + Key: aws.String("ac@d&e"), + }, + }, + { + { + Key: aws.String("ac/d/unwanted"), + }, + { + Key: aws.String("a&k.b"), + }, + }, + }, + continuationTokens: []string{"token1", "token2"}, + s3Delimiter: "&", + keyPrefix: "a", + filter: &objectkeyFilter{"unwanted"}, + callback: &objectkeyCallback{"a&k.b"}, + expectTokens: []string{"", "token1", "token2"}, + expectKeys: []string{"a&b", "a@b", "a&foo&bar", "ac", "ac@d&e", "a&k.bgotyou"}, + expectFiles: []string{"b", "a@b", "foo&bar", "ac", 
"ac@d&e", "k.b"}, + expectObjectsDownloaded: 6, + }, + "error when getting object": { + destination: "error-bucket", + objectsLists: [][]s3types.Object{ + { + { + Key: aws.String("foo/bar"), + }, + { + Key: aws.String("baz"), + }, + }, + { + { + Key: aws.String("foo/zoo/bar"), + }, + { + Key: aws.String("foo/zoo/oii/bababoii"), + }, + }, + { + { + Key: aws.String("foo/zoo/baz"), + }, + { + Key: aws.String("foo/zoo/oii/yee"), + }, + }, + }, + continuationTokens: []string{"token1", "token2"}, + getobjectFn: func(c *s3testing.TransferManagerLoggingClient, in *s3.GetObjectInput) (*s3.GetObjectOutput, error) { + if aws.ToString(in.Key) == "foo/zoo/bar" { + return nil, fmt.Errorf("mocking error") + } + return &s3.GetObjectOutput{ + Body: ioutil.NopCloser(bytes.NewReader(c.Data)), + ContentLength: aws.Int64(int64(len(c.Data))), + PartsCount: aws.Int32(c.PartsCount), + ETag: aws.String(etag), + }, nil + }, + expectErr: "mocking error", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + s3Client, params := s3testing.NewDownloadDirectoryClient() + s3Client.ListObjectsData = c.objectsLists + s3Client.ContinuationTokens = c.continuationTokens + if c.getobjectFn == nil { + s3Client.GetObjectFn = s3testing.PartGetObjectFn + } else { + s3Client.GetObjectFn = c.getobjectFn + } + s3Client.Data = make([]byte, 0) + s3Client.PartsCount = 1 + mgr := New(s3Client, Options{}) + + dstPath := filepath.Join(root, c.destination) + defer os.RemoveAll(dstPath) + resp, err := mgr.DownloadDirectory(context.Background(), &DownloadDirectoryInput{ + Bucket: "mock-bucket", + Destination: dstPath, + KeyPrefix: c.keyPrefix, + S3Delimiter: c.s3Delimiter, + Filter: c.filter, + Callback: c.callback, + }) + + if err != nil { + if c.expectErr == "" { + t.Fatalf("expect not error, got %v", err) + } else if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect %s error message to be in %s", e, a) + } + } else if c.expectErr != "" { + t.Fatalf("expect error 
%s, got none", c.expectErr) + } + if c.expectErr != "" { + return + } + + if e, a := c.expectObjectsDownloaded, resp.ObjectsDownloaded; e != a { + t.Errorf("expect %d objects downloaded, got %d", e, a) + } + + var actualTokens []string + var actualKeys []string + for _, param := range *params { + if input, ok := param.(*s3.ListObjectsV2Input); ok { + actualTokens = append(actualTokens, aws.ToString(input.ContinuationToken)) + } else if input, ok := param.(*s3.GetObjectInput); ok { + actualKeys = append(actualKeys, aws.ToString(input.Key)) + } else { + t.Fatalf("error when casting captured inputs") + } + } + + if e, a := c.expectTokens, actualTokens; !reflect.DeepEqual(e, a) { + t.Errorf("expect continuation tokens to be %v, got %v", e, a) + } + + sort.Strings(actualKeys) + sort.Strings(c.expectKeys) + if e, a := c.expectKeys, actualKeys; !reflect.DeepEqual(e, a) { + t.Errorf("expect downloaded keys to be %v, got %v", e, a) + } + + delimiter := c.s3Delimiter + if delimiter == "" { + delimiter = "/" + } + for _, file := range c.expectFiles { + path := filepath.Join(dstPath, strings.ReplaceAll(file, delimiter, string(os.PathSeparator))) + _, err := os.Stat(path) + if os.IsNotExist(err) { + t.Errorf("expect %s to be downloaded, got none", path) + } + } + }) + } +} diff --git a/feature/s3/transfermanager/internal/testing/client.go b/feature/s3/transfermanager/internal/testing/client.go index c7a31e202d3..66e1c537ed9 100644 --- a/feature/s3/transfermanager/internal/testing/client.go +++ b/feature/s3/transfermanager/internal/testing/client.go @@ -15,6 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" ) var etag = "myetag" @@ -46,7 +47,12 @@ type TransferManagerLoggingClient struct { Etags []string ErrReaders []TestErrReader - index int + + // params for keyprefix download test + ListObjectsData [][]s3types.Object + ContinuationTokens []string + + index int m sync.Mutex 
@@ -213,6 +219,7 @@ func (c *TransferManagerLoggingClient) GetObject(ctx context.Context, params *s3 c.m.Lock() defer c.m.Unlock() + c.traceOperation("GetObject", params) c.GetObjectInvocations++ if params.Range != nil { @@ -243,6 +250,30 @@ func (c *TransferManagerLoggingClient) HeadObject(ctx context.Context, params *s }, nil } +// ListObjectsV2 is the S3 ListObjectsV2 API +func (c *TransferManagerLoggingClient) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + c.m.Lock() + defer c.m.Unlock() + + c.traceOperation("ListObjectsV2", params) + + var nextToken *string + var isTruncated bool + if c.index < len(c.ContinuationTokens) { + nextToken = aws.String(c.ContinuationTokens[c.index]) + isTruncated = true + } + + out := &s3.ListObjectsV2Output{ + Contents: c.ListObjectsData[c.index], + NextContinuationToken: nextToken, + IsTruncated: aws.Bool(isTruncated), + } + c.index++ + + return out, nil +} + // NewUploadLoggingClient returns a new TransferManagerLoggingClient for upload testing. 
func NewUploadLoggingClient(ignoredOps []string) (*TransferManagerLoggingClient, *[]string, *[]interface{}) { c := &TransferManagerLoggingClient{ @@ -400,3 +431,10 @@ func NewUploadDirectoryClient(ignoredOps []string) (*TransferManagerLoggingClien return c, &c.Params } + +// NewDownloadDirectoryClient returns a new TransferManagerLoggingClient for download directory testing +func NewDownloadDirectoryClient() (*TransferManagerLoggingClient, *[]interface{}) { + c := &TransferManagerLoggingClient{} + + return c, &c.Params +} diff --git a/feature/s3/transfermanager/setup_integ_test.go b/feature/s3/transfermanager/setup_integ_test.go index 924e05f18a5..65045142274 100644 --- a/feature/s3/transfermanager/setup_integ_test.go +++ b/feature/s3/transfermanager/setup_integ_test.go @@ -426,7 +426,6 @@ func testUploadDirectory(t *testing.T, bucket string, testData uploadDirectoryTe if err != nil { t.Fatalf("error when writing test file %s: %v", path, err) } - // defer os.Remove(path) key := strings.Replace(f, "/", delimiter, -1) if testData.KeyPrefix != "" { key = testData.KeyPrefix + delimiter + key @@ -476,7 +475,92 @@ func testUploadDirectory(t *testing.T, bucket string, testData uploadDirectoryTe t.Errorf("no data recorded for object %s", key) } if e, a := expectData, b; !bytes.EqualFold(e, a) { - t.Errorf("expect %s, got %s", e, a) + t.Errorf("for object %s, expect %s, got %s", key, e, a) + } + } +} + +type downloadDirectoryTestData struct { + ObjectsSize map[string]int64 + Delimiter string + KeyPrefix string + ExpectObjectsDownloaded int + ExpectFiles []string + ExpectError string +} + +func testDownloadDirectory(t *testing.T, bucket string, testData downloadDirectoryTestData) { + _, filename, _, _ := runtime.Caller(0) + dst := filepath.Join(filepath.Dir(filename), "testdata", "integ") + defer os.RemoveAll(dst) + + delimiter := testData.Delimiter + if delimiter == "" { + delimiter = "/" + } + keyprefix := testData.KeyPrefix + if keyprefix != "" && 
!strings.HasSuffix(keyprefix, delimiter) { + keyprefix = keyprefix + delimiter + } + expectFiles := map[string][]byte{} + for key, size := range testData.ObjectsSize { + fileBuf := make([]byte, size) + _, err := rand.Read(fileBuf) + if err != nil { + t.Fatalf("error when mocking test data for object %s", key) + } + _, err = s3Client.PutObject(context.Background(), + &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: bytes.NewReader(fileBuf), + }) + if err != nil { + t.Fatalf("error when putting object %s", key) + } + file := filepath.Join(strings.ReplaceAll(strings.TrimPrefix(key, keyprefix), delimiter, string(os.PathSeparator))) + expectFiles[file] = fileBuf + } + + out, err := s3TransferManagerClient.DownloadDirectory(context.Background(), &DownloadDirectoryInput{ + Bucket: bucket, + Destination: dst, + KeyPrefix: testData.KeyPrefix, + S3Delimiter: testData.Delimiter, + }) + if err != nil { + if len(testData.ExpectError) == 0 { + t.Fatalf("expect no error, got %v", err) + } + if e, a := testData.ExpectError, err.Error(); !strings.Contains(a, e) { + t.Fatalf("expect error to contain %v, got %v", e, a) + } + } else { + if e := testData.ExpectError; len(e) != 0 { + t.Fatalf("expect error: %v, got none", e) + } + } + if len(testData.ExpectError) != 0 { + return + } + + if e, a := testData.ExpectObjectsDownloaded, out.ObjectsDownloaded; e != a { + t.Errorf("expect %d objects downloaded, got %d", e, a) + } + for _, file := range testData.ExpectFiles { + f := strings.ReplaceAll(file, "/", string(os.PathSeparator)) + path := filepath.Join(dst, f) + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("error when reading downloaded file %s: %v", path, err) + } + expectData, ok := expectFiles[f] + if !ok { + t.Errorf("no data recorded for file %s", path) + continue + } + if e, a := expectData, b; !bytes.EqualFold(e, a) { + t.Errorf("for file %s, expect %s, got %s", f, e, a) } } } diff --git 
a/feature/s3/transfermanager/upload_directory_test.go b/feature/s3/transfermanager/upload_directory_test.go index 3b38b3c352b..e2a55706d45 100644 --- a/feature/s3/transfermanager/upload_directory_test.go +++ b/feature/s3/transfermanager/upload_directory_test.go @@ -409,10 +409,9 @@ func TestUploadDirectory(t *testing.T) { } else if e, a := c.expectErr, err.Error(); !strings.Contains(a, e) { t.Fatalf("expect %s error message to be in %s", e, a) } - } else { - if c.expectErr != "" { - t.Fatalf("expect error %s, got none", c.expectErr) - } + } else if c.expectErr != "" { + t.Fatalf("expect error %s, got none", c.expectErr) + } if err != nil {