Skip to content

Commit 22354bc

Browse files
committed
fix for custom UA
# Conflicts: # src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h # src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp
1 parent c054379 commit 22354bc

File tree

3 files changed

+132
-0
lines changed

3 files changed

+132
-0
lines changed

src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,67 @@
66
#pragma once
77

88
#include <aws/core/Core_EXPORTS.h>
9+
#include <aws/core/client/UserAgent.h>
910
#include <aws/core/utils/memory/stl/AWSString.h>
1011
#include <aws/core/utils/DateTime.h>
1112
namespace Aws
1213
{
1314
namespace Auth
1415
{
16+
/**
17+
* Context class for credential resolution that tracks features used during credential retrieval.
18+
*/
19+
class AWS_CORE_API CredentialsResolutionContext
20+
{
21+
public:
22+
// Default constructor - no features tracked
23+
CredentialsResolutionContext() = default;
24+
25+
/**
26+
* Add a user agent feature to track credential usage.
27+
*/
28+
void AddUserAgentFeature(Aws::Client::UserAgentFeature feature)
29+
{
30+
m_features.insert(feature);
31+
}
32+
33+
/**
34+
* Get all tracked credential features.
35+
*/
36+
const Aws::Set<Aws::Client::UserAgentFeature> GetUserAgentFeatures() const
37+
{
38+
return m_features;
39+
}
40+
41+
/**
42+
* Set the user agent for this context
43+
*/
44+
void SetUserAgent(const std::shared_ptr<Aws::Client::UserAgent>& userAgent)
45+
{
46+
m_userAgent = userAgent;
47+
}
48+
49+
/**
50+
* Get the user agent associated with this context
51+
*/
52+
const std::shared_ptr<Aws::Client::UserAgent>& GetUserAgent() const
53+
{
54+
return m_userAgent;
55+
}
56+
57+
/**
58+
* Check if this context has a custom user agent
59+
*/
60+
bool HasCustomUserAgent() const
61+
{
62+
return m_userAgent && m_userAgent->HasOverrideUserAgent();
63+
}
64+
65+
private:
66+
Aws::Set<Aws::Client::UserAgentFeature> m_features;
67+
std::shared_ptr<Aws::Client::UserAgent> m_userAgent;
68+
};
69+
1570
/**
1671
* Simple data object around aws credentials
1772
*/
@@ -214,12 +269,24 @@ namespace Aws
214269
m_expiration = expiration;
215270
}
216271

272+
/**
273+
* Gets credential resolution context. this is information about the call
274+
* such as what credentials provider was used to to resolve the credentials
275+
*/
276+
inline CredentialsResolutionContext GetContext() { return m_context; }
277+
278+
/**
279+
* Adds a user agent feature used during credentials resolution to the credentials
280+
* context. This is useful to track which credentials provider was used.
281+
*/
282+
inline void AddUserAgentFeature(Aws::Client::UserAgentFeature feature) { m_context.AddUserAgentFeature(feature); }
217283
private:
218284
Aws::String m_accessKeyId;
219285
Aws::String m_secretKey;
220286
Aws::String m_sessionToken;
221287
Aws::Utils::DateTime m_expiration;
222288
Aws::String m_accountId;
289+
CredentialsResolutionContext m_context;
223290
};
224291
}
225292
}

src/aws-cpp-sdk-core/include/aws/core/client/UserAgent.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class AWS_CORE_API UserAgent {
3838
public:
3939
explicit UserAgent(const ClientConfiguration& clientConfiguration, const Aws::String& retryStrategyName, const Aws::String& apiName);
4040
Aws::String SerializeWithFeatures(const Aws::Set<UserAgentFeature>& features) const;
41+
42+
bool HasOverrideUserAgent() const { return !m_overrideUserAgent.empty(); }
4143
void SetApiName(const Aws::String& apiName) { m_api = apiName; }
4244
void AddLegacyFeature(const Aws::String& legacyFeature);
4345

src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <aws/core/auth/signer/AWSAuthSignerHelper.h>
99

1010
#include <aws/core/auth/AWSCredentialsProvider.h>
11+
#include <aws/core/client/UserAgent.h>
1112
#include <aws/core/http/HttpRequest.h>
1213
#include <aws/core/http/URI.h>
1314
#include <aws/core/utils/DateTime.h>
@@ -25,6 +26,7 @@
2526

2627
#include <iomanip>
2728
#include <cstring>
29+
#include <numeric>
2830

2931
using namespace Aws;
3032
using namespace Aws::Client;
@@ -81,6 +83,8 @@ bool AWSAuthV4Signer::SignRequestWithSigV4a(Aws::Http::HttpRequest& request, con
8183
bool signBody, long long expirationTimeInSeconds, Aws::Crt::Auth::SignatureType signatureType) const
8284
{
8385
AWSCredentials credentials = GetCredentials(request.GetServiceSpecificParameters());
86+
87+
UpdateUserAgentWithCredentialFeatures(request, credentials.GetContext());
8488
auto crtCredentials = Aws::MakeShared<Aws::Crt::Auth::Credentials>(v4AsymmetricLogTag,
8589
Aws::Crt::ByteCursorFromCString(credentials.GetAWSAccessKeyId().c_str()),
8690
Aws::Crt::ByteCursorFromCString(credentials.GetAWSSecretKey().c_str()),
@@ -336,6 +340,9 @@ bool AWSAuthV4Signer::SignRequestWithCreds(Aws::Http::HttpRequest& request, cons
336340
bool AWSAuthV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool signBody) const
337341
{
338342
AWSCredentials credentials = GetCredentials(request.GetServiceSpecificParameters());
343+
344+
UpdateUserAgentWithCredentialFeatures(request, credentials.GetContext());
345+
339346
return SignRequestWithCreds(request, credentials, region, serviceName, signBody);
340347
}
341348

@@ -464,6 +471,9 @@ bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, const Aws:
464471
bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, long long expirationTimeInSeconds) const
465472
{
466473
AWSCredentials credentials = GetCredentials(request.GetServiceSpecificParameters());
474+
475+
UpdateUserAgentWithCredentialFeatures(request, credentials.GetContext());
476+
467477
return PresignRequest(request, credentials, region,serviceName, expirationTimeInSeconds );
468478
}
469479

@@ -595,3 +605,56 @@ Aws::Auth::AWSCredentials AWSAuthV4Signer::GetCredentials(const std::shared_ptr<
595605
AWS_UNREFERENCED_PARAM(serviceSpecificParameters);
596606
return m_credentialsProvider->GetAWSCredentials();
597607
}
608+
609+
void AWSAuthV4Signer::UpdateUserAgentWithCredentialFeatures(Aws::Http::HttpRequest& request, const Aws::Auth::CredentialsResolutionContext& context) const {
610+
if (!request.HasHeader(USER_AGENT)) {
611+
AWS_LOGSTREAM_DEBUG(v4LogTag, "Request does not have User-Agent header, skipping credential feature update");
612+
return;
613+
}
614+
615+
if (context.HasCustomUserAgent()) {
616+
AWS_LOGSTREAM_DEBUG(v4LogTag, "Custom User-Agent detected, skipping credential feature update");
617+
return;
618+
}
619+
620+
const auto features = context.GetUserAgentFeatures();
621+
if (features.empty()) {
622+
AWS_LOGSTREAM_DEBUG(v4LogTag, "No credential features to add to User-Agent");
623+
return;
624+
}
625+
626+
std::vector<Aws::String> businessMetrics(features.size());
627+
std::transform(features.begin(),
628+
features.end(),
629+
businessMetrics.begin(),
630+
[](UserAgentFeature feature) -> Aws::String { return UserAgent::BusinessMetricForFeature(feature); });
631+
632+
const auto credentialFeatures = std::accumulate(std::next(businessMetrics.begin()),
633+
businessMetrics.end(),
634+
businessMetrics.front(),
635+
[](const Aws::String& a, const Aws::String& b) {
636+
return a + "," + b;
637+
});
638+
639+
const auto userAgent = request.GetHeaderValue(USER_AGENT);
640+
auto userAgentParsed = Aws::Utils::StringUtils::Split(userAgent, ' ');
641+
auto metricsSegment = std::find_if(userAgentParsed.begin(), userAgentParsed.end(),
642+
[](const Aws::String& value) { return value.find("m/") != Aws::String::npos; });
643+
644+
if (metricsSegment != userAgentParsed.end()) {
645+
// Add new metrics to existing metrics section
646+
*metricsSegment = Aws::String{*metricsSegment + "," + credentialFeatures};
647+
} else {
648+
// No metrics section exists, add new one
649+
userAgentParsed.push_back("m/" + credentialFeatures);
650+
}
651+
652+
// Reassemble all parts with spaces
653+
const auto newUserAgent = std::accumulate(std::next(userAgentParsed.begin()),
654+
userAgentParsed.end(),
655+
userAgentParsed.front(),
656+
[](const Aws::String& a, const Aws::String& b) {
657+
return a + " " + b;
658+
});
659+
request.SetUserAgent(newUserAgent);
660+
}

0 commit comments

Comments
 (0)