@@ -390,6 +390,17 @@ HttpResponseOutcome AWSClient::AttemptExhaustively(const Aws::Http::URI& uri,
390390 {
391391 newUri.SetAuthority (newEndpoint);
392392 }
393+
394+ // Save checksum information from the original request if we haven't already - safe to assume that the checksum has been finalized, since we have sent and received a response
395+ RetryContext context = request.GetRetryContext ();
396+ if (context.m_requestHash == nullptr ) {
397+ auto originalRequestHash = httpRequest->GetRequestHash ();
398+ if (originalRequestHash.second != nullptr ) {
399+ context.m_requestHash = Aws::MakeShared<std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>>>(AWS_CLIENT_LOG_TAG, originalRequestHash);
400+ request.SetRetryContext (context);
401+ }
402+ }
403+
393404 httpRequest = CreateHttpRequest (newUri, method, request.GetResponseStreamFactory ());
394405
395406 httpRequest->SetHeaderValue (Http::SDK_INVOCATION_ID_HEADER, invocationId);
@@ -930,6 +941,13 @@ void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, co
930941 httpRequest->SetContinueRequestHandle (request.GetContinueRequestHandler ());
931942 httpRequest->SetServiceSpecificParameters (request.GetServiceSpecificParameters ());
932943 request.AddQueryStringParameters (httpRequest->GetUri ());
944+
945+ // check for retry context, if present use it
946+ RetryContext context = request.GetRetryContext ();
947+ if (context.m_requestHash != nullptr ) {
948+ const auto hash = Aws::MakeShared<Aws::Utils::Crypto::PrecalculatedHash>(smithy::client::AWS_SMITHY_CLIENT_CHECKSUM, HashingUtils::Base64Encode (context.m_requestHash ->second ->GetHash ().GetResult ()));
949+ httpRequest->SetRequestHash (context.m_requestHash ->first , hash);
950+ }
933951}
934952
935953Aws::String AWSClient::GeneratePresignedUrl (const Aws::Http::URI& uri, Aws::Http::HttpMethod method, long long expirationInSeconds, const std::shared_ptr<Aws::Http::ServiceSpecificParameters> serviceSpecificParameter)
@@ -1020,38 +1038,37 @@ std::shared_ptr<Aws::Http::HttpResponse> AWSClient::MakeHttpRequest(std::shared_
10201038 return m_httpClient->MakeRequest (request, m_readRateLimiter.get (), m_writeRateLimiter.get ());
10211039}
10221040
1023- void AWSClient::AppendRecursionDetectionHeader (std::shared_ptr<Aws::Http::HttpRequest> ioRequest)
1024- {
1025- if (!ioRequest || ioRequest->HasHeader (Aws::Http::X_AMZN_TRACE_ID_HEADER)) {
1026- return ;
1027- }
1028- Aws::String awsLambdaFunctionName = Aws::Environment::GetEnv (AWS_LAMBDA_FUNCTION_NAME);
1029- if (awsLambdaFunctionName.empty ()) {
1030- return ;
1031- }
1032- Aws::String xAmznTraceIdVal = Aws::Environment::GetEnv (X_AMZN_TRACE_ID);
1033- if (xAmznTraceIdVal.empty ()) {
1034- return ;
1035- }
1041+ void AWSClient::AppendRecursionDetectionHeader (std::shared_ptr<Aws::Http::HttpRequest> ioRequest) {
1042+ if (!ioRequest || ioRequest->HasHeader (Aws::Http::X_AMZN_TRACE_ID_HEADER)) {
1043+ return ;
1044+ }
1045+ Aws::String awsLambdaFunctionName = Aws::Environment::GetEnv (AWS_LAMBDA_FUNCTION_NAME);
1046+ if (awsLambdaFunctionName.empty ()) {
1047+ return ;
1048+ }
1049+ Aws::String xAmznTraceIdVal = Aws::Environment::GetEnv (X_AMZN_TRACE_ID);
1050+ if (xAmznTraceIdVal.empty ()) {
1051+ return ;
1052+ }
10361053
1037- // Escape all non-printable ASCII characters by percent encoding
1038- Aws::OStringStream xAmznTraceIdValEncodedStr;
1039- for (const char ch : xAmznTraceIdVal)
1054+ // Escape all non-printable ASCII characters by percent encoding
1055+ Aws::OStringStream xAmznTraceIdValEncodedStr;
1056+ for (const char ch : xAmznTraceIdVal)
1057+ {
1058+ if (ch >= 0x20 && ch <= 0x7e ) // ascii chars [32-126] or [' ' to '~'] are not escaped
10401059 {
1041- if (ch >= 0x20 && ch <= 0x7e ) // ascii chars [32-126] or [' ' to '~'] are not escaped
1042- {
1043- xAmznTraceIdValEncodedStr << ch;
1044- }
1045- else
1046- {
1047- // A percent-encoded octet is encoded as a character triplet
1048- xAmznTraceIdValEncodedStr << ' %' // consisting of the percent character "%"
1049- << std::hex << std::setfill (' 0' ) << std::setw (2 ) << std::uppercase
1050- << (size_t ) ch // followed by the two hexadecimal digits representing that octet's numeric value
1051- << std::dec << std::setfill (' ' ) << std::setw (0 ) << std::nouppercase;
1052- }
1060+ xAmznTraceIdValEncodedStr << ch;
10531061 }
1054- xAmznTraceIdVal = xAmznTraceIdValEncodedStr.str ();
1062+ else
1063+ {
1064+ // A percent-encoded octet is encoded as a character triplet
1065+ xAmznTraceIdValEncodedStr << ' %' // consisting of the percent character "%"
1066+ << std::hex << std::setfill (' 0' ) << std::setw (2 ) << std::uppercase
1067+ << (size_t ) ch // followed by the two hexadecimal digits representing that octet's numeric value
1068+ << std::dec << std::setfill (' ' ) << std::setw (0 ) << std::nouppercase;
1069+ }
1070+ }
1071+ xAmznTraceIdVal = xAmznTraceIdValEncodedStr.str ();
10551072
1056- ioRequest->SetHeaderValue (Aws::Http::X_AMZN_TRACE_ID_HEADER, xAmznTraceIdVal);
1073+ ioRequest->SetHeaderValue (Aws::Http::X_AMZN_TRACE_ID_HEADER, xAmznTraceIdVal);
10571074}
0 commit comments