diff --git a/src/aws-cpp-sdk-core/include/aws/core/AmazonWebServiceRequest.h b/src/aws-cpp-sdk-core/include/aws/core/AmazonWebServiceRequest.h index 2b2a9e13c7d7..a16c939a31cb 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/AmazonWebServiceRequest.h +++ b/src/aws-cpp-sdk-core/include/aws/core/AmazonWebServiceRequest.h @@ -35,6 +35,10 @@ namespace Aws typedef std::function RequestRetryHandler; typedef std::function RequestSignedHandler; + struct RetryContext { + std::shared_ptr>> m_requestHash; + }; + /** * Base level abstraction for all modeled AWS requests */ @@ -222,7 +226,11 @@ namespace Aws */ Aws::Set GetUserAgentFeatures() const { return m_userAgentFeatures; } - inline virtual bool RequestChecksumRequired() const { return false; } + inline virtual bool RequestChecksumRequired() const { return false; } + + RetryContext GetRetryContext() const { return m_retryContext; } + + void SetRetryContext(const RetryContext& context) const { m_retryContext = context; } protected: /** * Default does nothing. Override this to convert what would otherwise be the payload of the @@ -242,6 +250,7 @@ namespace Aws RequestRetryHandler m_requestRetryHandler; mutable std::shared_ptr m_serviceSpecificParameters; mutable Aws::Set m_userAgentFeatures; + mutable Aws::RetryContext m_retryContext; }; } // namespace Aws diff --git a/src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h b/src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h index a860642bf573..627a16757b2a 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h @@ -16,6 +16,7 @@ #include #include #include +#include #include @@ -82,49 +83,10 @@ class ChecksumInterceptor : public smithy::interceptor::Interceptor { // For non-streaming payload, the resolved checksum location is always header. // For streaming payload, the resolved checksum location depends on whether it is an unsigned payload, we let // AwsAuthSigner decide it. - if (request.IsStreaming() && checksumValueAndAlgorithmProvided) { - addChecksumFeatureForChecksumName(checksumAlgorithmName, request); - const auto hash = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM, checksumHeader->second); - httpRequest->SetRequestHash(checksumAlgorithmName, hash); - } else if (checksumValueAndAlgorithmProvided) { - httpRequest->SetHeaderValue(checksumType, checksumHeader->second); - } else if (checksumAlgorithmName == "crc64nvme") { - request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC64); - if (request.IsStreaming()) { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM)); - } else { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC64(*(GetBodyStream(request))))); - } - } else if (checksumAlgorithmName == "crc32") { - request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32); - if (request.IsStreaming()) { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM)); - } else { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC32(*(GetBodyStream(request))))); - } - } else if (checksumAlgorithmName == "crc32c") { - request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32C); - if (request.IsStreaming()) { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM)); - } else { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC32C(*(GetBodyStream(request))))); - } - } else if (checksumAlgorithmName == "sha256") { - request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA256); - if (request.IsStreaming()) { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM)); - } else { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateSHA256(*(GetBodyStream(request))))); - } - } else if (checksumAlgorithmName == "sha1") { - request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA1); - if (request.IsStreaming()) { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM)); - } else { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateSHA1(*(GetBodyStream(request))))); - } + if (checksumValueAndAlgorithmProvided) { + handleProvidedChecksum(request, httpRequest, checksumAlgorithmName, checksumType, checksumHeader->second); } else { - AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM, "Checksum algorithm: " << checksumAlgorithmName << "is not supported by SDK."); + calculateAndSetChecksum(request, httpRequest, checksumAlgorithmName, checksumType); } } } @@ -133,30 +95,7 @@ class ChecksumInterceptor : public smithy::interceptor::Interceptor { if ((!request.GetResponseChecksumAlgorithmNames().empty() && m_responseChecksumValidation == ResponseChecksumValidation::WHEN_SUPPORTED) || request.ShouldValidateResponseChecksum()) { - for (const Aws::String& responseChecksumAlgorithmName : request.GetResponseChecksumAlgorithmNames()) { - const auto responseChecksum = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str()); - if (responseChecksum == "crc32c") { - std::shared_ptr crc32c = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); - httpRequest->AddResponseValidationHash("crc32c", crc32c); - } else if (responseChecksum == "crc32") { - std::shared_ptr crc32 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); - httpRequest->AddResponseValidationHash("crc32", crc32); - } else if (responseChecksum == "sha1") { - std::shared_ptr sha1 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); - httpRequest->AddResponseValidationHash("sha1", sha1); - } else if (responseChecksum == "sha256") { - std::shared_ptr sha256 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); - httpRequest->AddResponseValidationHash("sha256", sha256); - } else if (responseChecksum == "crc64nvme") { - std::shared_ptr crc64 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); - httpRequest->AddResponseValidationHash("crc64nvme", crc64); - } else { - AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM, - "Checksum algorithm: " << responseChecksum << " is not supported in validating response body yet."); - } - } - // we have to set the checksum mode to enabled if it was not previously - httpRequest->SetHeaderValue("x-amz-checksum-mode", "enabled"); + SetResponseChecksum(request, httpRequest); } return httpRequest; @@ -233,6 +172,94 @@ class ChecksumInterceptor : public smithy::interceptor::Interceptor { } } + void handleProvidedChecksum(const Aws::AmazonWebServiceRequest& request, std::shared_ptr httpRequest, + const Aws::String& algorithm, const Aws::String& checksumType, const Aws::String& checksumValue) { + if (request.IsStreaming()) { + addChecksumFeatureForChecksumName(algorithm, request); + if (httpRequest->GetRequestHash().second == nullptr) { + auto hash = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM, checksumValue); + httpRequest->SetRequestHash(algorithm, hash); + } + } else { + httpRequest->SetHeaderValue(checksumType, checksumValue); + } + } + + void calculateAndSetChecksum(const Aws::AmazonWebServiceRequest& request, std::shared_ptr httpRequest, + const Aws::String& algorithm, const Aws::String& checksumType) { + static const Aws::Array, 5> algorithmMap = {{ + std::make_pair("crc64nvme", ChecksumHandler{ + []() { return Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); }, + [](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateCRC64(stream)); }, + Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC64}), + std::make_pair("crc32", ChecksumHandler{ + []() { return Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); }, + [](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateCRC32(stream)); }, + Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32}), + std::make_pair("crc32c", ChecksumHandler{ + []() { return Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); }, + [](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateCRC32C(stream)); }, + Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32C}), + std::make_pair("sha256", ChecksumHandler{ + []() { return Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); }, + [](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateSHA256(stream)); }, + Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA256}), + std::make_pair("sha1", ChecksumHandler{ + []() { return Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); }, + [](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateSHA1(stream)); }, + Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA1}) + }}; + + const auto it = find_if(algorithmMap.begin(), algorithmMap.end(), [&](const std::pair &pair) { return algorithm == pair.first; }); + if (it == algorithmMap.end()) { + AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM, "Checksum algorithm: " << algorithm << " is not supported by SDK."); + return; + } + + request.AddUserAgentFeature(it->second.userAgentFeature); + + if (request.IsStreaming()) { + if (httpRequest->GetRequestHash().second == nullptr) { + httpRequest->SetRequestHash(algorithm, it->second.createHash()); + } + } else { + httpRequest->SetHeaderValue(checksumType, it->second.calculateHash(*GetBodyStream(request))); + } + } + + void SetResponseChecksum(const Aws::AmazonWebServiceRequest& request, std::shared_ptr httpRequest) { + for (const Aws::String& responseChecksumAlgorithmName : request.GetResponseChecksumAlgorithmNames()) { + const auto responseChecksum = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str()); + if (responseChecksum == "crc32c") { + std::shared_ptr crc32c = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); + httpRequest->AddResponseValidationHash("crc32c", crc32c); + } else if (responseChecksum == "crc32") { + std::shared_ptr crc32 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); + httpRequest->AddResponseValidationHash("crc32", crc32); + } else if (responseChecksum == "sha1") { + std::shared_ptr sha1 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); + httpRequest->AddResponseValidationHash("sha1", sha1); + } else if (responseChecksum == "sha256") { + std::shared_ptr sha256 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); + httpRequest->AddResponseValidationHash("sha256", sha256); + } else if (responseChecksum == "crc64nvme") { + std::shared_ptr crc64 = Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM); + httpRequest->AddResponseValidationHash("crc64nvme", crc64); + } else { + AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM, + "Checksum algorithm: " << responseChecksum << " is not supported in validating response body yet."); + } + } + // we have to set the checksum mode to enabled if it was not previously + httpRequest->SetHeaderValue("x-amz-checksum-mode", "enabled"); + } + + struct ChecksumHandler { + std::function()> createHash; + std::function calculateHash; + Aws::Client::UserAgentFeature userAgentFeature; + }; + RequestChecksumCalculation m_requestChecksumCalculation{RequestChecksumCalculation::WHEN_SUPPORTED}; ResponseChecksumValidation m_responseChecksumValidation{ResponseChecksumValidation::WHEN_SUPPORTED}; }; diff --git a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp index 80ece5c90515..18bdfd60e474 100644 --- a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp +++ b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp @@ -389,6 +389,17 @@ HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri, { newUri.SetAuthority(newEndpoint); } + + // Save checksum information from the original request if we haven't already - safe to assume that the checksum has been finalized, since we have sent and received a response + RetryContext context = request.GetRetryContext(); + if (context.m_requestHash == nullptr) { + auto originalRequestHash = httpRequest->GetRequestHash(); + if (originalRequestHash.second != nullptr) { + context.m_requestHash = Aws::MakeShared>>(AWS_CLIENT_LOG_TAG, originalRequestHash); + request.SetRetryContext(context); + } + } + httpRequest = CreateHttpRequest(newUri, method, request.GetResponseStreamFactory()); httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId); @@ -920,6 +931,13 @@ void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, co httpRequest->SetContinueRequestHandle(request.GetContinueRequestHandler()); httpRequest->SetServiceSpecificParameters(request.GetServiceSpecificParameters()); request.AddQueryStringParameters(httpRequest->GetUri()); + + // check for retry context, if present use it + RetryContext context = request.GetRetryContext(); + if (context.m_requestHash != nullptr) { + const auto hash = Aws::MakeShared(smithy::client::AWS_SMITHY_CLIENT_CHECKSUM, HashingUtils::Base64Encode(context.m_requestHash->second->GetHash().GetResult())); + httpRequest->SetRequestHash(context.m_requestHash->first, hash); + } } Aws::String AWSClient::GeneratePresignedUrl(const Aws::Http::URI& uri, Aws::Http::HttpMethod method, long long expirationInSeconds, const std::shared_ptr serviceSpecificParameter)