Skip to content

Commit faa33a6

Browse files
authored
Reuse checksum on streaming retries + refactor ChecksumInterceptor (#3500)
* Reuse checksum on streaming request retries * Fix precalculated hash allocationtag namespace * ChecksumInterceptor.h - Refactor ModifyBeforeSigning logic for readability/scalability * ChecksumInterceptor.h - use memory safe Aws::Array instead of Aws::UnorderedMap * AmazonWebServiceRequest - RetryContext to be private, and contain a pointer to the request hash pair * Remove explicit nullptr initialization
1 parent 7d25b7f commit faa33a6

File tree

3 files changed

+121
-67
lines changed

3 files changed

+121
-67
lines changed

src/aws-cpp-sdk-core/include/aws/core/AmazonWebServiceRequest.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ namespace Aws
3535
typedef std::function<void(const AmazonWebServiceRequest&)> RequestRetryHandler;
3636
typedef std::function<void(const Aws::Http::HttpRequest&)> RequestSignedHandler;
3737

38+
struct RetryContext {
39+
std::shared_ptr<std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>>> m_requestHash;
40+
};
41+
3842
/**
3943
* Base level abstraction for all modeled AWS requests
4044
*/
@@ -222,7 +226,11 @@ namespace Aws
222226
*/
223227
Aws::Set<Aws::Client::UserAgentFeature> GetUserAgentFeatures() const { return m_userAgentFeatures; }
224228

225-
inline virtual bool RequestChecksumRequired() const { return false; }
229+
inline virtual bool RequestChecksumRequired() const { return false; }
230+
231+
RetryContext GetRetryContext() const { return m_retryContext; }
232+
233+
void SetRetryContext(const RetryContext& context) const { m_retryContext = context; }
226234
protected:
227235
/**
228236
* Default does nothing. Override this to convert what would otherwise be the payload of the
@@ -242,6 +250,7 @@ namespace Aws
242250
RequestRetryHandler m_requestRetryHandler;
243251
mutable std::shared_ptr<Aws::Http::ServiceSpecificParameters> m_serviceSpecificParameters;
244252
mutable Aws::Set<Client::UserAgentFeature> m_userAgentFeatures;
253+
mutable Aws::RetryContext m_retryContext;
245254
};
246255

247256
} // namespace Aws

src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h

Lines changed: 93 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <aws/core/utils/crypto/Sha1.h>
1717
#include <aws/core/utils/crypto/Sha256.h>
1818
#include <smithy/interceptor/Interceptor.h>
19+
#include <aws/core/utils/memory/stl/AWSArray.h>
1920

2021
#include <iomanip>
2122

@@ -82,49 +83,10 @@ class ChecksumInterceptor : public smithy::interceptor::Interceptor {
8283
// For non-streaming payload, the resolved checksum location is always header.
8384
// For streaming payload, the resolved checksum location depends on whether it is an unsigned payload, we let
8485
// AwsAuthSigner decide it.
85-
if (request.IsStreaming() && checksumValueAndAlgorithmProvided) {
86-
addChecksumFeatureForChecksumName(checksumAlgorithmName, request);
87-
const auto hash = Aws::MakeShared<PrecalculatedHash>(AWS_SMITHY_CLIENT_CHECKSUM, checksumHeader->second);
88-
httpRequest->SetRequestHash(checksumAlgorithmName, hash);
89-
} else if (checksumValueAndAlgorithmProvided) {
90-
httpRequest->SetHeaderValue(checksumType, checksumHeader->second);
91-
} else if (checksumAlgorithmName == "crc64nvme") {
92-
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC64);
93-
if (request.IsStreaming()) {
94-
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<CRC64>(AWS_SMITHY_CLIENT_CHECKSUM));
95-
} else {
96-
httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC64(*(GetBodyStream(request)))));
97-
}
98-
} else if (checksumAlgorithmName == "crc32") {
99-
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32);
100-
if (request.IsStreaming()) {
101-
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<CRC32>(AWS_SMITHY_CLIENT_CHECKSUM));
102-
} else {
103-
httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC32(*(GetBodyStream(request)))));
104-
}
105-
} else if (checksumAlgorithmName == "crc32c") {
106-
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32C);
107-
if (request.IsStreaming()) {
108-
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<CRC32C>(AWS_SMITHY_CLIENT_CHECKSUM));
109-
} else {
110-
httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC32C(*(GetBodyStream(request)))));
111-
}
112-
} else if (checksumAlgorithmName == "sha256") {
113-
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA256);
114-
if (request.IsStreaming()) {
115-
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<Sha256>(AWS_SMITHY_CLIENT_CHECKSUM));
116-
} else {
117-
httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateSHA256(*(GetBodyStream(request)))));
118-
}
119-
} else if (checksumAlgorithmName == "sha1") {
120-
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA1);
121-
if (request.IsStreaming()) {
122-
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<Sha1>(AWS_SMITHY_CLIENT_CHECKSUM));
123-
} else {
124-
httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateSHA1(*(GetBodyStream(request)))));
125-
}
86+
if (checksumValueAndAlgorithmProvided) {
87+
handleProvidedChecksum(request, httpRequest, checksumAlgorithmName, checksumType, checksumHeader->second);
12688
} else {
127-
AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM, "Checksum algorithm: " << checksumAlgorithmName << "is not supported by SDK.");
89+
calculateAndSetChecksum(request, httpRequest, checksumAlgorithmName, checksumType);
12890
}
12991
}
13092
}
@@ -133,30 +95,7 @@ class ChecksumInterceptor : public smithy::interceptor::Interceptor {
13395
if ((!request.GetResponseChecksumAlgorithmNames().empty() &&
13496
m_responseChecksumValidation == ResponseChecksumValidation::WHEN_SUPPORTED) ||
13597
request.ShouldValidateResponseChecksum()) {
136-
for (const Aws::String& responseChecksumAlgorithmName : request.GetResponseChecksumAlgorithmNames()) {
137-
const auto responseChecksum = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str());
138-
if (responseChecksum == "crc32c") {
139-
std::shared_ptr<CRC32C> crc32c = Aws::MakeShared<CRC32C>(AWS_SMITHY_CLIENT_CHECKSUM);
140-
httpRequest->AddResponseValidationHash("crc32c", crc32c);
141-
} else if (responseChecksum == "crc32") {
142-
std::shared_ptr<CRC32> crc32 = Aws::MakeShared<CRC32>(AWS_SMITHY_CLIENT_CHECKSUM);
143-
httpRequest->AddResponseValidationHash("crc32", crc32);
144-
} else if (responseChecksum == "sha1") {
145-
std::shared_ptr<Sha1> sha1 = Aws::MakeShared<Sha1>(AWS_SMITHY_CLIENT_CHECKSUM);
146-
httpRequest->AddResponseValidationHash("sha1", sha1);
147-
} else if (responseChecksum == "sha256") {
148-
std::shared_ptr<Sha256> sha256 = Aws::MakeShared<Sha256>(AWS_SMITHY_CLIENT_CHECKSUM);
149-
httpRequest->AddResponseValidationHash("sha256", sha256);
150-
} else if (responseChecksum == "crc64nvme") {
151-
std::shared_ptr<CRC64> crc64 = Aws::MakeShared<CRC64>(AWS_SMITHY_CLIENT_CHECKSUM);
152-
httpRequest->AddResponseValidationHash("crc64nvme", crc64);
153-
} else {
154-
AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM,
155-
"Checksum algorithm: " << responseChecksum << " is not supported in validating response body yet.");
156-
}
157-
}
158-
// we have to set the checksum mode to enabled if it was not previously
159-
httpRequest->SetHeaderValue("x-amz-checksum-mode", "enabled");
98+
SetResponseChecksum(request, httpRequest);
16099
}
161100

162101
return httpRequest;
@@ -233,6 +172,94 @@ class ChecksumInterceptor : public smithy::interceptor::Interceptor {
233172
}
234173
}
235174

175+
void handleProvidedChecksum(const Aws::AmazonWebServiceRequest& request, std::shared_ptr<Aws::Http::HttpRequest> httpRequest,
176+
const Aws::String& algorithm, const Aws::String& checksumType, const Aws::String& checksumValue) {
177+
if (request.IsStreaming()) {
178+
addChecksumFeatureForChecksumName(algorithm, request);
179+
if (httpRequest->GetRequestHash().second == nullptr) {
180+
auto hash = Aws::MakeShared<PrecalculatedHash>(AWS_SMITHY_CLIENT_CHECKSUM, checksumValue);
181+
httpRequest->SetRequestHash(algorithm, hash);
182+
}
183+
} else {
184+
httpRequest->SetHeaderValue(checksumType, checksumValue);
185+
}
186+
}
187+
188+
void calculateAndSetChecksum(const Aws::AmazonWebServiceRequest& request, std::shared_ptr<Aws::Http::HttpRequest> httpRequest,
189+
const Aws::String& algorithm, const Aws::String& checksumType) {
190+
static const Aws::Array<std::pair<const char*, ChecksumHandler>, 5> algorithmMap = {{
191+
std::make_pair("crc64nvme", ChecksumHandler{
192+
[]() { return Aws::MakeShared<CRC64>(AWS_SMITHY_CLIENT_CHECKSUM); },
193+
[](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateCRC64(stream)); },
194+
Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC64}),
195+
std::make_pair("crc32", ChecksumHandler{
196+
[]() { return Aws::MakeShared<CRC32>(AWS_SMITHY_CLIENT_CHECKSUM); },
197+
[](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateCRC32(stream)); },
198+
Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32}),
199+
std::make_pair("crc32c", ChecksumHandler{
200+
[]() { return Aws::MakeShared<CRC32C>(AWS_SMITHY_CLIENT_CHECKSUM); },
201+
[](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateCRC32C(stream)); },
202+
Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_CRC32C}),
203+
std::make_pair("sha256", ChecksumHandler{
204+
[]() { return Aws::MakeShared<Sha256>(AWS_SMITHY_CLIENT_CHECKSUM); },
205+
[](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateSHA256(stream)); },
206+
Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA256}),
207+
std::make_pair("sha1", ChecksumHandler{
208+
[]() { return Aws::MakeShared<Sha1>(AWS_SMITHY_CLIENT_CHECKSUM); },
209+
[](Aws::IOStream& stream) { return HashingUtils::Base64Encode(HashingUtils::CalculateSHA1(stream)); },
210+
Aws::Client::UserAgentFeature::FLEXIBLE_CHECKSUMS_REQ_SHA1})
211+
}};
212+
213+
const auto it = find_if(algorithmMap.begin(), algorithmMap.end(), [&](const std::pair<const char*, ChecksumHandler> &pair) { return algorithm == pair.first; });
214+
if (it == algorithmMap.end()) {
215+
AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM, "Checksum algorithm: " << algorithm << " is not supported by SDK.");
216+
return;
217+
}
218+
219+
request.AddUserAgentFeature(it->second.userAgentFeature);
220+
221+
if (request.IsStreaming()) {
222+
if (httpRequest->GetRequestHash().second == nullptr) {
223+
httpRequest->SetRequestHash(algorithm, it->second.createHash());
224+
}
225+
} else {
226+
httpRequest->SetHeaderValue(checksumType, it->second.calculateHash(*GetBodyStream(request)));
227+
}
228+
}
229+
230+
void SetResponseChecksum(const Aws::AmazonWebServiceRequest& request, std::shared_ptr<Aws::Http::HttpRequest> httpRequest) {
231+
for (const Aws::String& responseChecksumAlgorithmName : request.GetResponseChecksumAlgorithmNames()) {
232+
const auto responseChecksum = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str());
233+
if (responseChecksum == "crc32c") {
234+
std::shared_ptr<CRC32C> crc32c = Aws::MakeShared<CRC32C>(AWS_SMITHY_CLIENT_CHECKSUM);
235+
httpRequest->AddResponseValidationHash("crc32c", crc32c);
236+
} else if (responseChecksum == "crc32") {
237+
std::shared_ptr<CRC32> crc32 = Aws::MakeShared<CRC32>(AWS_SMITHY_CLIENT_CHECKSUM);
238+
httpRequest->AddResponseValidationHash("crc32", crc32);
239+
} else if (responseChecksum == "sha1") {
240+
std::shared_ptr<Sha1> sha1 = Aws::MakeShared<Sha1>(AWS_SMITHY_CLIENT_CHECKSUM);
241+
httpRequest->AddResponseValidationHash("sha1", sha1);
242+
} else if (responseChecksum == "sha256") {
243+
std::shared_ptr<Sha256> sha256 = Aws::MakeShared<Sha256>(AWS_SMITHY_CLIENT_CHECKSUM);
244+
httpRequest->AddResponseValidationHash("sha256", sha256);
245+
} else if (responseChecksum == "crc64nvme") {
246+
std::shared_ptr<CRC64> crc64 = Aws::MakeShared<CRC64>(AWS_SMITHY_CLIENT_CHECKSUM);
247+
httpRequest->AddResponseValidationHash("crc64nvme", crc64);
248+
} else {
249+
AWS_LOGSTREAM_WARN(AWS_SMITHY_CLIENT_CHECKSUM,
250+
"Checksum algorithm: " << responseChecksum << " is not supported in validating response body yet.");
251+
}
252+
}
253+
// we have to set the checksum mode to enabled if it was not previously
254+
httpRequest->SetHeaderValue("x-amz-checksum-mode", "enabled");
255+
}
256+
257+
struct ChecksumHandler {
258+
std::function<std::shared_ptr<Aws::Utils::Crypto::Hash>()> createHash;
259+
std::function<Aws::String(Aws::IOStream&)> calculateHash;
260+
Aws::Client::UserAgentFeature userAgentFeature;
261+
};
262+
236263
RequestChecksumCalculation m_requestChecksumCalculation{RequestChecksumCalculation::WHEN_SUPPORTED};
237264
ResponseChecksumValidation m_responseChecksumValidation{ResponseChecksumValidation::WHEN_SUPPORTED};
238265
};

src/aws-cpp-sdk-core/source/client/AWSClient.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,17 @@ HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri,
389389
{
390390
newUri.SetAuthority(newEndpoint);
391391
}
392+
393+
// Save checksum information from the original request if we haven't already - safe to assume that the checksum has been finalized, since we have sent and received a response
394+
RetryContext context = request.GetRetryContext();
395+
if (context.m_requestHash == nullptr) {
396+
auto originalRequestHash = httpRequest->GetRequestHash();
397+
if (originalRequestHash.second != nullptr) {
398+
context.m_requestHash = Aws::MakeShared<std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>>>(AWS_CLIENT_LOG_TAG, originalRequestHash);
399+
request.SetRetryContext(context);
400+
}
401+
}
402+
392403
httpRequest = CreateHttpRequest(newUri, method, request.GetResponseStreamFactory());
393404

394405
httpRequest->SetHeaderValue(Http::SDK_INVOCATION_ID_HEADER, invocationId);
@@ -920,6 +931,13 @@ void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, co
920931
httpRequest->SetContinueRequestHandle(request.GetContinueRequestHandler());
921932
httpRequest->SetServiceSpecificParameters(request.GetServiceSpecificParameters());
922933
request.AddQueryStringParameters(httpRequest->GetUri());
934+
935+
// check for retry context, if present use it
936+
RetryContext context = request.GetRetryContext();
937+
if (context.m_requestHash != nullptr) {
938+
const auto hash = Aws::MakeShared<Aws::Utils::Crypto::PrecalculatedHash>(smithy::client::AWS_SMITHY_CLIENT_CHECKSUM, HashingUtils::Base64Encode(context.m_requestHash->second->GetHash().GetResult()));
939+
httpRequest->SetRequestHash(context.m_requestHash->first, hash);
940+
}
923941
}
924942

925943
Aws::String AWSClient::GeneratePresignedUrl(const Aws::Http::URI& uri, Aws::Http::HttpMethod method, long long expirationInSeconds, const std::shared_ptr<Aws::Http::ServiceSpecificParameters> serviceSpecificParameter)

0 commit comments

Comments
 (0)