Skip to content

Commit 85f553f

Browse files
committed
adding unit tests
1 parent 6ff87ee commit 85f553f

File tree

2 files changed

+158
-7
lines changed

2 files changed

+158
-7
lines changed

src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
1616
#include <aws/core/utils/memory/stl/AWSStringStream.h>
1717
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
18+
#include <aws/common/byte_order.h>
19+
#include <cstring>
1820
#include <aws/crt/checksum/CRC.h>
1921
#include <aws/s3/S3Client.h>
2022
#include <aws/s3/model/AbortMultipartUploadRequest.h>
@@ -24,8 +26,6 @@
2426
#include <aws/s3/model/ListObjectsV2Request.h>
2527
#include <aws/transfer/TransferManager.h>
2628
#include <sys/stat.h>
27-
#include <aws/common/byte_order.h>
28-
#include <cstring>
2929

3030
#include <algorithm>
3131
#include <fstream>
@@ -409,12 +409,9 @@ namespace Aws
409409

410410
const auto fullObjectHashCalculator = [](const std::shared_ptr<TransferHandle>& handle, bool isRetry, S3::Model::ChecksumAlgorithm algorithm) -> std::shared_ptr<Aws::Utils::Crypto::Hash> {
411411
if (handle->GetChecksum().empty() && !isRetry) {
412-
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) {
412+
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32 || algorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
413413
return Aws::MakeShared<Aws::Utils::Crypto::CRC32>("TransferManager");
414414
}
415-
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
416-
return Aws::MakeShared<Aws::Utils::Crypto::CRC32C>("TransferManager");
417-
}
418415
if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) {
419416
return Aws::MakeShared<Aws::Utils::Crypto::Sha1>("TransferManager");
420417
}
@@ -1223,9 +1220,21 @@ namespace Aws
12231220
Aws::IOStream* bufferStream = partState->GetDownloadPartStream();
12241221
assert(bufferStream);
12251222

1226-
if (m_transferConfig.validateChecksums) { handle->AddChecksumForPart(bufferStream, partState); }
12271223
Aws::String errMsg{handle->WritePartToDownloadStream(bufferStream, partState->GetRangeBegin())};
12281224
if (errMsg.empty()) {
1225+
if (!outcome.GetResult().GetChecksumCRC32().empty()) {
1226+
partState->SetChecksum(outcome.GetResult().GetChecksumCRC32());
1227+
} else if (!outcome.GetResult().GetChecksumCRC32C().empty()) {
1228+
partState->SetChecksum(outcome.GetResult().GetChecksumCRC32C());
1229+
} else if (!outcome.GetResult().GetChecksumCRC64NVME().empty()) {
1230+
partState->SetChecksum(outcome.GetResult().GetChecksumCRC64NVME());
1231+
} else if (!outcome.GetResult().GetChecksumSHA1().empty()) {
1232+
partState->SetChecksum(outcome.GetResult().GetChecksumSHA1());
1233+
} else if (!outcome.GetResult().GetChecksumSHA256().empty()) {
1234+
partState->SetChecksum(outcome.GetResult().GetChecksumSHA256());
1235+
} else {
1236+
if (m_transferConfig.validateChecksums) { handle->AddChecksumForPart(bufferStream, partState); }
1237+
}
12291238
handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag());
12301239
} else {
12311240
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,

tests/aws-cpp-sdk-transfer-unit-tests/TransferUnitTests.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
#include <aws/s3/S3Client.h>
55
#include <aws/s3/model/GetObjectRequest.h>
66
#include <aws/s3/model/GetObjectResult.h>
7+
#include <aws/s3/model/HeadObjectRequest.h>
8+
#include <aws/s3/model/HeadObjectResult.h>
79
#include <aws/transfer/TransferManager.h>
810
#include <aws/testing/AwsTestHelpers.h>
911
#include <aws/testing/MemoryTesting.h>
1012
#include <sstream>
13+
#include <fstream>
1114

1215
using namespace Aws;
1316
using namespace Aws::S3;
@@ -36,11 +39,87 @@ class MockS3Client : public S3Client {
3639
}
3740
};
3841

42+
class MockMultipartS3Client : public S3Client {
43+
public:
44+
Aws::String FULL_OBJECT_CHECKSUM; //"SBi/K+1ooBg="
45+
MockMultipartS3Client(Aws::String expected_checksum) : S3Client() {
46+
FULL_OBJECT_CHECKSUM = expected_checksum;
47+
};
48+
49+
HeadObjectOutcome HeadObject(const HeadObjectRequest&) const override {
50+
HeadObjectResult result;
51+
result.SetContentLength(78643200);
52+
result.SetChecksumCRC64NVME(FULL_OBJECT_CHECKSUM);
53+
result.SetChecksumType(Aws::S3::Model::ChecksumType::FULL_OBJECT); // This is key!
54+
result.SetETag("\"test-etag-12345\""); // Add ETag
55+
return HeadObjectOutcome(std::move(result));
56+
}
57+
58+
GetObjectOutcome GetObject(const GetObjectRequest& request) const override {
59+
GetObjectResult result;
60+
61+
const uint64_t totalSize = 78643200;
62+
const uint64_t partSize = 5242880;
63+
const std::vector<std::string> checksums = {
64+
"wAQOkgd/LJk=", "zfmsUj6AZfs=", "oyENjcGDHcY=", "wAQOkgd/LJk=", "zfmsUj6AZfs=",
65+
"oyENjcGDHcY=", "wAQOkgd/LJk=", "zfmsUj6AZfs=", "oyENjcGDHcY=", "wAQOkgd/LJk=",
66+
"zfmsUj6AZfs=", "oyENjcGDHcY=", "wAQOkgd/LJk=", "zfmsUj6AZfs=", "oyENjcGDHcY="
67+
};
68+
69+
if (request.RangeHasBeenSet()) {
70+
auto range = request.GetRange();
71+
size_t dashPos = range.find('-');
72+
uint64_t start = std::stoull(range.substr(6, dashPos - 6));
73+
uint64_t end = std::stoull(range.substr(dashPos + 1));
74+
uint64_t size = end - start + 1;
75+
76+
int partNum = start / partSize;
77+
if (partNum < 15) {
78+
result.SetContentRange("bytes " + std::to_string(start) + "-" + std::to_string(end) + "/" + std::to_string(totalSize));
79+
result.SetChecksumCRC64NVME(checksums[partNum]);
80+
result.SetContentLength(size);
81+
result.SetETag("\"part-etag-" + std::to_string(partNum) + "\"");
82+
83+
// Call the response stream factory if provided
84+
if (request.GetResponseStreamFactory()) {
85+
auto responseStream = request.GetResponseStreamFactory()();
86+
87+
// Write part-specific data to the response stream
88+
char partChar = 'A' + (partNum % 3);
89+
for (uint64_t i = 0; i < size; ++i) {
90+
responseStream->put(partChar);
91+
}
92+
responseStream->flush();
93+
94+
// Simulate data received callback to track bytes transferred
95+
if (request.GetDataReceivedEventHandler()) {
96+
request.GetDataReceivedEventHandler()(nullptr, nullptr, size);
97+
}
98+
99+
result.ReplaceBody(responseStream);
100+
} else {
101+
// Fallback for non-factory requests
102+
auto stream = Aws::New<std::stringstream>(ALLOCATION_TAG);
103+
char partChar = 'A' + (partNum % 3);
104+
for (uint64_t i = 0; i < size; ++i) {
105+
stream->put(partChar);
106+
}
107+
stream->seekg(0, std::ios::beg);
108+
result.ReplaceBody(stream);
109+
}
110+
}
111+
}
112+
113+
return GetObjectOutcome(std::move(result));
114+
}
115+
};
116+
39117
class TransferUnitTest : public testing::Test {
40118
protected:
41119
void SetUp() override {
42120
executor = Aws::MakeShared<PooledThreadExecutor>(ALLOCATION_TAG, 1);
43121
mockS3Client = Aws::MakeShared<MockS3Client>(ALLOCATION_TAG);
122+
mockMultipartS3Client = Aws::MakeShared<MockMultipartS3Client>(ALLOCATION_TAG, "SBi/K+1ooBg=");
44123
}
45124

46125
static void SetUpTestSuite() {
@@ -53,6 +132,7 @@ class TransferUnitTest : public testing::Test {
53132

54133
std::shared_ptr<PooledThreadExecutor> executor;
55134
std::shared_ptr<MockS3Client> mockS3Client;
135+
std::shared_ptr<MockMultipartS3Client> mockMultipartS3Client;
56136
static SDKOptions _options;
57137
};
58138

@@ -73,3 +153,65 @@ TEST_F(TransferUnitTest, ContentValidationShouldFail) {
73153

74154
EXPECT_EQ(TransferStatus::FAILED, handle->GetStatus());
75155
}
156+
157+
TEST_F(TransferUnitTest, MultipartDownloadTest) {
158+
TransferManagerConfiguration config(executor.get());
159+
config.s3Client = mockMultipartS3Client;
160+
config.bufferSize = 5242880; // 5MB to ensure multipart
161+
auto transferManager = TransferManager::Create(config);
162+
163+
// Create a temporary file for download since multipart needs seekable stream
164+
std::string tempFile = "/tmp/test_download_" + std::to_string(rand());
165+
auto createStreamFn = [tempFile]() -> Aws::IOStream* {
166+
return Aws::New<Aws::FStream>(ALLOCATION_TAG, tempFile.c_str(),
167+
std::ios_base::out | std::ios_base::in |
168+
std::ios_base::binary | std::ios_base::trunc);
169+
};
170+
171+
// Download the full 78MB file
172+
auto handle = transferManager->DownloadFile("test-bucket", "test-key", createStreamFn);
173+
handle->WaitUntilFinished();
174+
175+
// Test multipart download functionality - should PASS with correct checksum
176+
EXPECT_TRUE(handle->IsMultipart());
177+
EXPECT_EQ(78643200u, handle->GetBytesTotalSize());
178+
EXPECT_EQ(15u, handle->GetCompletedParts().size());
179+
EXPECT_EQ(0u, handle->GetFailedParts().size());
180+
EXPECT_EQ(0u, handle->GetPendingParts().size());
181+
EXPECT_EQ(TransferStatus::COMPLETED, handle->GetStatus()); // Should PASS
182+
183+
// Clean up
184+
std::remove(tempFile.c_str());
185+
}
186+
187+
TEST_F(TransferUnitTest, MultipartDownloadTest_Fail) {
188+
TransferManagerConfiguration config(executor.get());
189+
auto mockFailClient = Aws::MakeShared<MockMultipartS3Client>(ALLOCATION_TAG, "WRONG_CHECKSUM=");
190+
config.s3Client = mockFailClient;
191+
config.bufferSize = 5242880; // 5MB to ensure multipart
192+
auto transferManager = TransferManager::Create(config);
193+
194+
// Create a temporary file for download since multipart needs seekable stream
195+
std::string tempFile = "/tmp/test_download_" + std::to_string(rand());
196+
auto createStreamFn = [tempFile]() -> Aws::IOStream* {
197+
return Aws::New<Aws::FStream>(ALLOCATION_TAG, tempFile.c_str(),
198+
std::ios_base::out | std::ios_base::in |
199+
std::ios_base::binary | std::ios_base::trunc);
200+
};
201+
202+
// Download the full 78MB file
203+
auto handle = transferManager->DownloadFile("test-bucket", "test-key", createStreamFn);
204+
handle->WaitUntilFinished();
205+
206+
// Test multipart download functionality - should FAIL with wrong checksum
207+
EXPECT_TRUE(handle->IsMultipart());
208+
EXPECT_EQ(78643200u, handle->GetBytesTotalSize());
209+
EXPECT_EQ(15u, handle->GetCompletedParts().size());
210+
EXPECT_EQ(0u, handle->GetFailedParts().size());
211+
EXPECT_EQ(0u, handle->GetPendingParts().size());
212+
EXPECT_EQ(TransferStatus::FAILED, handle->GetStatus()); // Should FAIL due to wrong checksum
213+
EXPECT_EQ("Full-object checksum validation failed", handle->GetLastError().GetMessage());
214+
215+
// Clean up
216+
std::remove(tempFile.c_str());
217+
}

0 commit comments

Comments
 (0)