Skip to content

Commit 2e36fcc

Browse files
committed
feat: Add callback-based credential provider tracking system
Implement lightweight callback mechanism for tracking credential provider usage in User-Agent strings without breaking backward compatibility. Changes: - Add SetCredentialTrackingCallback() and NotifyCredentialUsage() to AWSCredentialsProvider base class - Implement callback forwarding in AWSCredentialsProviderChain for provider chains - Add TrackCredentialProviderUsage() method to AWSClient for setting up tracking callbacks - Update EnvironmentAWSCredentialsProvider to call NotifyCredentialUsage() when credentials found - Add unit test for environment credential tracking (currently failing due to GetCredentialsProvider() limitation) The callback system allows individual credential providers to notify when they successfully retrieve credentials, enabling User-Agent tracking without requiring API changes or thread-local variables. Note: Current implementation has limitation where GetCredentialsProvider() returns null when provider is embedded in signer, preventing callback setup in some scenarios.
1 parent 522b399 commit 2e36fcc

File tree

6 files changed

+100
-70
lines changed

6 files changed

+100
-70
lines changed

src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProvider.h

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
#include <aws/core/config/AWSProfileConfigLoader.h>
2020
#include <aws/core/client/RetryStrategy.h>
2121
#include <memory>
22+
#include <functional>
2223

2324
namespace Aws
2425
{
26+
class AmazonWebServiceRequest;
27+
2528
namespace Client
2629
{
2730
struct ClientConfiguration;
@@ -32,17 +35,6 @@ namespace Aws
3235

3336
constexpr int AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD = 5 * 1000;
3437

35-
/**
36-
* Enum to identify credential provider types for tracking purposes
37-
*/
38-
enum class CredentialProviderType
39-
{
40-
DEFAULT,
41-
ENVIRONMENT,
42-
// ... add other types as needed
43-
44-
};
45-
4638
/**
4739
* Returns the full path of the config file.
4840
*/
@@ -73,22 +65,27 @@ namespace Aws
7365
* Initializes provider. Sets last Loaded time count to 0, forcing a refresh on the
7466
* first call to GetAWSCredentials.
7567
*/
76-
AWSCredentialsProvider(CredentialProviderType providerType = CredentialProviderType::DEFAULT)
77-
: m_lastLoadedMs(0), m_providerType(providerType)
68+
AWSCredentialsProvider() : m_lastLoadedMs(0)
7869
{
7970
}
8071

81-
/**
82-
* Get the provider type for tracking purposes
83-
*/
84-
CredentialProviderType GetProviderType() const { return m_providerType; }
85-
8672
virtual ~AWSCredentialsProvider() = default;
8773

8874
/**
8975
* The core of the credential provider interface. Override this method to control how credentials are retrieved.
9076
*/
9177
virtual AWSCredentials GetAWSCredentials() = 0;
78+
79+
/**
80+
* Set callback for credential usage tracking
81+
*/
82+
virtual void SetCredentialTrackingCallback(std::function<void()> callback) { m_trackingCallback = callback; }
83+
84+
protected:
85+
/**
86+
* Call this when credentials are successfully retrieved for tracking
87+
*/
88+
void NotifyCredentialUsage() { if (m_trackingCallback) m_trackingCallback(); }
9289

9390
protected:
9491
/**
@@ -100,7 +97,7 @@ namespace Aws
10097
mutable Aws::Utils::Threading::ReaderWriterLock m_reloadLock;
10198
private:
10299
long long m_lastLoadedMs;
103-
CredentialProviderType m_providerType;
100+
std::function<void()> m_trackingCallback;
104101
};
105102

106103
/**
@@ -114,7 +111,6 @@ namespace Aws
114111
* Returns empty credentials object.
115112
*/
116113
inline AWSCredentials GetAWSCredentials() override { return AWSCredentials(); }
117-
118114
};
119115

120116
/**
@@ -161,7 +157,7 @@ namespace Aws
161157
/**
162158
* Initializes environment credentials provider
163159
*/
164-
EnvironmentAWSCredentialsProvider() : AWSCredentialsProvider(CredentialProviderType::ENVIRONMENT) {}
160+
EnvironmentAWSCredentialsProvider() = default;
165161

166162
/**
167163
* Reads AWS credentials from the Environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and AWS_SESSION_TOKEN if they exist. If they

src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ namespace Aws
2727
* When a credentials provider in the chain returns empty credentials,
2828
* We go on to the next provider until we have either exhausted the installed providers in the chain or something returns non-empty credentials.
2929
*/
30-
virtual AWSCredentials GetAWSCredentials();
30+
AWSCredentials GetAWSCredentials() override;
31+
32+
/**
33+
* Override to store and forward callback to providers in chain
34+
*/
35+
void SetCredentialTrackingCallback(std::function<void()> callback) override { m_chainTrackingCallback = callback; }
3136

3237
/**
3338
* Gets all providers stored in this chain.
@@ -50,6 +55,7 @@ namespace Aws
5055
Aws::Vector<std::shared_ptr<AWSCredentialsProvider> > m_providerChain;
5156
std::shared_ptr<AWSCredentialsProvider> m_cachedProvider;
5257
mutable Aws::Utils::Threading::ReaderWriterLock m_cachedProviderLock;
58+
std::function<void()> m_chainTrackingCallback;
5359
};
5460

5561
/**

src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include <aws/core/client/AWSError.h>
1919
#include <aws/core/utils/StringUtils.h>
2020
#include <aws/core/utils/xml/XmlSerializer.h>
21+
#include <aws/core/AmazonWebServiceRequest.h>
22+
#include <aws/core/client/UserAgent.h>
2123
#include <cstdlib>
2224
#include <fstream>
2325
#include <string.h>
@@ -102,11 +104,12 @@ AWSCredentials EnvironmentAWSCredentialsProvider::GetAWSCredentials() //pass in
102104
AWS_LOGSTREAM_DEBUG(ENVIRONMENT_LOG_TAG, "Found accountId");
103105
}
104106
}
105-
107+
106108
if (!credentials.IsEmpty()) {
107-
// TODO: this will work
108-
// TODO: how does request get here?????
109-
// request.AddFeature(ENV_VAR)
109+
// TODO: this will work
110+
// TODO: how does request get here?????
111+
// request.AddFeature(ENV_VAR)
112+
NotifyCredentialUsage();
110113
}
111114

112115
return credentials;

src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials()
2222
{
2323
ReaderLockGuard lock(m_cachedProviderLock);
2424
if (m_cachedProvider) {
25+
// Forward callback to cached provider
26+
m_cachedProvider->SetCredentialTrackingCallback(m_chainTrackingCallback);
2527
AWSCredentials credentials = m_cachedProvider->GetAWSCredentials();
2628
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
2729
{
@@ -31,6 +33,8 @@ AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials()
3133
lock.UpgradeToWriterLock();
3234
for (auto&& credentialsProvider : m_providerChain)
3335
{
36+
// Forward callback to each provider in chain
37+
credentialsProvider->SetCredentialTrackingCallback(m_chainTrackingCallback);
3438
AWSCredentials credentials = credentialsProvider->GetAWSCredentials();
3539
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
3640
{

src/aws-cpp-sdk-core/source/client/AWSClient.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,25 +1075,19 @@ void AWSClient::TrackCredentialProviderUsage(const Aws::AmazonWebServiceRequest&
10751075
return;
10761076
}
10771077

1078-
// Get the provider type
1079-
auto providerType = credentialsProvider->GetProviderType();
1080-
1081-
switch (providerType)
1082-
{
1083-
case Aws::Auth::CredentialProviderType::ENVIRONMENT:
1084-
// Environment credentials are being used
1078+
// Set up callback for credential tracking
1079+
if (credentialsProvider) {
1080+
credentialsProvider->SetCredentialTrackingCallback([&request]() {
10851081
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::CREDENTIALS_ENV_VARS);
10861082
AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Added CREDENTIALS_ENV_VARS to User-Agent");
1087-
break;
1083+
});
1084+
}
1085+
1086+
// Trigger credential retrieval to enable tracking
1087+
if (credentialsProvider) {
1088+
credentialsProvider->GetAWSCredentials();
10881089

1089-
// Add more provider types as needed
1090-
default:
1091-
// For provider chains or unknown types, check environment variables as fallback
1092-
if (!Aws::Environment::GetEnv("AWS_ACCESS_KEY_ID").empty())
1093-
{
1094-
request.AddUserAgentFeature(Aws::Client::UserAgentFeature::CREDENTIALS_ENV_VARS);
1095-
AWS_LOGSTREAM_DEBUG(AWS_CLIENT_LOG_TAG, "Added CREDENTIALS_ENV_VARS to User-Agent (fallback detection)");
1096-
}
1097-
break;
1090+
// Clear callback
1091+
credentialsProvider->SetCredentialTrackingCallback(nullptr);
10981092
}
10991093
}

tests/aws-cpp-sdk-core-tests/aws/auth/CredentialTrackingTest.cpp

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,50 @@
88
#include <aws/testing/mocks/aws/client/MockAWSClient.h>
99
#include <aws/testing/mocks/http/MockHttpClient.h>
1010
#include <aws/core/auth/AWSCredentialsProvider.h>
11+
#include <aws/core/auth/AWSCredentialsProviderChain.h>
1112
#include <aws/core/client/ClientConfiguration.h>
13+
#include <aws/core/client/AWSClient.h>
14+
#include <aws/core/auth/AWSAuthSigner.h>
1215
#include <aws/core/platform/Environment.h>
1316
#include <aws/core/utils/StringUtils.h>
17+
#include <iostream>
1418

1519
using namespace Aws::Client;
1620
using namespace Aws::Auth;
1721
using namespace Aws::Http;
1822

1923
static const char ALLOCATION_TAG[] = "CredentialTrackingTest";
2024

25+
// Custom client that uses default credential provider chain for testing
26+
class CredentialTestingClient : public Aws::Client::AWSClient
27+
{
28+
public:
29+
explicit CredentialTestingClient(const Aws::Client::ClientConfiguration& configuration)
30+
: AWSClient(configuration,
31+
Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(ALLOCATION_TAG,
32+
Aws::MakeShared<DefaultAWSCredentialsProviderChain>(ALLOCATION_TAG),
33+
"service", configuration.region),
34+
Aws::MakeShared<MockAWSErrorMarshaller>(ALLOCATION_TAG))
35+
{
36+
// Client created with DefaultAWSCredentialsProviderChain
37+
}
38+
39+
Aws::Client::HttpResponseOutcome MakeRequest(const Aws::AmazonWebServiceRequest& request)
40+
{
41+
auto uri = Aws::Http::URI("https://test.com");
42+
return AWSClient::AttemptExhaustively(uri, request, Aws::Http::HttpMethod::HTTP_POST, Aws::Auth::SIGV4_SIGNER);
43+
}
44+
45+
const char* GetServiceClientName() const override { return "CredentialTestingClient"; }
46+
47+
protected:
48+
Aws::Client::AWSError<Aws::Client::CoreErrors> BuildAWSError(const std::shared_ptr<Aws::Http::HttpResponse>& response) const override
49+
{
50+
AWS_UNREFERENCED_PARAM(response);
51+
return Aws::Client::AWSError<Aws::Client::CoreErrors>(Aws::Client::CoreErrors::UNKNOWN, false);
52+
}
53+
};
54+
2155
class CredentialTrackingTest : public Aws::Testing::AwsCppSdkGTestSuite
2256
{
2357
protected:
@@ -49,33 +83,35 @@ TEST_F(CredentialTrackingTest, TestEnvironmentCredentialsTracking)
4983
Aws::Environment::SetEnv("AWS_SECRET_ACCESS_KEY", "test-secret-key", 1);
5084

5185
// Setup mock response
52-
auto request = CreateHttpRequest(Aws::Http::URI("http://test.com"),
53-
Aws::Http::HttpMethod::HTTP_POST,
54-
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
55-
auto response = Aws::MakeShared<Standard::StandardHttpResponse>(ALLOCATION_TAG, request);
56-
response->SetResponseCode(HttpResponseCode::OK);
57-
response->GetResponseBody() << "{}";
58-
mockHttpClient->AddResponseToReturn(response);
86+
std::shared_ptr<HttpRequest> requestTmp =
87+
CreateHttpRequest(Aws::Http::URI("dummy"), Aws::Http::HttpMethod::HTTP_POST,
88+
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
89+
auto successResponse = Aws::MakeShared<Standard::StandardHttpResponse>(ALLOCATION_TAG, requestTmp);
90+
successResponse->SetResponseCode(HttpResponseCode::OK);
91+
successResponse->GetResponseBody() << "{}";
92+
mockHttpClient->AddResponseToReturn(successResponse);
5993

6094
// Create client configuration
61-
ClientConfiguration config;
62-
config.region = Aws::Region::US_EAST_1;
95+
Aws::Client::ClientConfigurationInitValues cfgInit;
96+
cfgInit.shouldDisableIMDS = true;
97+
Aws::Client::ClientConfiguration clientConfig(cfgInit);
98+
clientConfig.region = Aws::Region::US_EAST_1;
6399

64-
// Create mock client
65-
MockAWSClient client(config);
100+
// Create credential testing client that uses default provider chain
101+
CredentialTestingClient client(clientConfig);
66102

67-
// Make a request
103+
// Create mock request
68104
AmazonWebServiceRequestMock mockRequest;
69-
auto outcome = client.MakeRequest(mockRequest);
70105

71-
// Verify request succeeded
72-
AWS_ASSERT_SUCCESS(outcome);
106+
// Make request
107+
auto outcome = client.MakeRequest(mockRequest);
108+
ASSERT_TRUE(outcome.IsSuccess());
73109

74110
// Verify User-Agent contains environment credentials tracking
75111
auto lastRequest = mockHttpClient->GetMostRecentHttpRequest();
76-
EXPECT_TRUE(lastRequest.HasUserAgent());
77-
const auto& userAgent = lastRequest.GetUserAgent();
78-
EXPECT_TRUE(!userAgent.empty());
112+
EXPECT_TRUE(lastRequest.HasHeader(Aws::Http::USER_AGENT_HEADER));
113+
const auto& userAgent = lastRequest.GetHeaderValue(Aws::Http::USER_AGENT_HEADER);
114+
EXPECT_FALSE(userAgent.empty());
79115

80116
const auto userAgentParsed = Aws::Utils::StringUtils::Split(userAgent, ' ');
81117

@@ -89,12 +125,3 @@ TEST_F(CredentialTrackingTest, TestEnvironmentCredentialsTracking)
89125
Aws::Environment::UnSetEnv("AWS_ACCESS_KEY_ID");
90126
Aws::Environment::UnSetEnv("AWS_SECRET_ACCESS_KEY");
91127
}
92-
93-
TEST_F(CredentialTrackingTest, TestEnvironmentProviderType)
94-
{
95-
// Test that EnvironmentAWSCredentialsProvider has correct provider type
96-
auto envProvider = Aws::MakeShared<EnvironmentAWSCredentialsProvider>(ALLOCATION_TAG);
97-
98-
// Verify the provider type is set correctly
99-
EXPECT_EQ(envProvider->GetProviderType(), Aws::Auth::CredentialProviderType::ENVIRONMENT);
100-
}

0 commit comments

Comments
 (0)