Skip to content

Commit f7b37ad

Browse files
committed
fix cache and threading concerns with STS Webidentity provider
1 parent d71a569 commit f7b37ad

File tree

3 files changed

+113
-28
lines changed

3 files changed

+113
-28
lines changed

src/aws-cpp-sdk-core/include/aws/core/auth/STSCredentialsProvider.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88

99
#include <aws/core/Core_EXPORTS.h>
1010
#include <aws/core/auth/AWSCredentialsProvider.h>
11+
#include <atomic>
12+
#include <memory>
1113

1214
namespace Aws {
1315
namespace Crt {
1416
namespace Auth {
1517
class ICredentialsProvider;
18+
class Credentials;
1619
}
1720
}
1821
}
@@ -46,10 +49,19 @@ namespace Aws
4649
INITIALIZED,
4750
SHUT_DOWN,
4851
} m_state{STATE::SHUT_DOWN};
49-
std::mutex m_refreshMutex;
50-
std::condition_variable m_refreshSignal;
52+
mutable std::mutex m_refreshMutex;
53+
mutable std::condition_variable m_refreshSignal;
5154
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> m_credentialsProvider;
5255
std::chrono::milliseconds m_providerFuturesTimeoutMs;
56+
57+
// Thread-safe credential fetch coordination
58+
mutable std::atomic<bool> m_refreshInProgress{false};
59+
mutable std::shared_ptr<AWSCredentials> m_pendingCredentials;
60+
61+
// Helper methods for credential retrieval
62+
AWSCredentials waitForSharedCredentials() const;
63+
AWSCredentials extractCredentialsFromCrt(const Aws::Crt::Auth::Credentials& crtCredentials) const;
64+
AWSCredentials fetchCredentialsAsync();
5365
};
5466
} // namespace Auth
5567
} // namespace Aws

src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,11 @@ namespace Aws
544544
* Time out for the credentials future call.
545545
*/
546546
std::chrono::milliseconds retrieveCredentialsFutureTimeout = std::chrono::seconds(10);
547+
548+
/**
549+
* How long a cached credential set will be used for
550+
*/
551+
std::chrono::milliseconds credentialCacheCacheTTL = std::chrono::minutes(50);
547552
} stsCredentialsProviderConfig;
548553
} credentialProviderConfig;
549554
};

src/aws-cpp-sdk-core/source/auth/STSCredentialsProvider.cpp

Lines changed: 94 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,27 @@ STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentials
3737
}
3838
return UUID::RandomUUID();
3939
}().c_str();
40-
m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity(stsConfig);
40+
41+
// Create underlying STS provider
42+
auto stsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity(stsConfig);
43+
if (!stsProvider || !stsProvider->IsValid()) {
44+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create underlying STS credentials provider");
45+
return;
46+
}
47+
48+
// Wrap with caching provider
49+
Aws::Crt::Auth::CredentialsProviderCachedConfig cachedConfig;
50+
cachedConfig.Provider = stsProvider;
51+
cachedConfig.CachedCredentialTTL = credentialsConfig.stsCredentialsProviderConfig.credentialCacheCacheTTL;
52+
53+
m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderCached(cachedConfig);
4154
if (m_credentialsProvider && m_credentialsProvider->IsValid()) {
4255
m_state = STATE::INITIALIZED;
56+
AWS_LOGSTREAM_INFO(STS_LOG_TAG, "STS credentials provider initialized with cache TTL "
57+
<< cachedConfig.CachedCredentialTTL.count()
58+
<< " ms");
4359
} else {
44-
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create STS credentials provider");
60+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create cached STS credentials provider");
4561
}
4662
}
4763

@@ -80,31 +96,15 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
8096
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "STSCredentialsProvider is not initialized, returning empty credentials");
8197
return AWSCredentials{};
8298
}
83-
AWSCredentials credentials{};
84-
auto refreshDone = false;
85-
m_credentialsProvider->GetCredentials(
86-
[this, &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
87-
{
88-
const std::unique_lock<std::mutex> lock{m_refreshMutex};
89-
if (errorCode != AWS_ERROR_SUCCESS) {
90-
AWS_LOGSTREAM_ERROR(STS_LOG_TAG, "Failed to get credentials from STS: " << errorCode);
91-
} else {
92-
const auto accountIdCursor = crtCredentials->GetAccessKeyId();
93-
credentials.SetAWSAccessKeyId({reinterpret_cast<char*>(accountIdCursor.ptr), accountIdCursor.len});
94-
const auto secretKeuCursor = crtCredentials->GetSecretAccessKey();
95-
credentials.SetAWSSecretKey({reinterpret_cast<char*>(secretKeuCursor.ptr), secretKeuCursor.len});
96-
const auto expiration = crtCredentials->GetExpirationTimepointInSeconds();
97-
credentials.SetExpiration(DateTime{static_cast<double>(expiration)});
98-
const auto sessionTokenCursor = crtCredentials->GetSessionToken();
99-
credentials.SetSessionToken({reinterpret_cast<char*>(sessionTokenCursor.ptr), sessionTokenCursor.len});
100-
}
101-
refreshDone = true;
102-
}
103-
m_refreshSignal.notify_one();
104-
});
10599

106-
std::unique_lock<std::mutex> lock{m_refreshMutex};
107-
m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone; });
100+
// Thread-safe check: If another thread is already fetching, wait for its result
101+
auto expected = false;
102+
if (!m_refreshInProgress.compare_exchange_strong(expected, true)) {
103+
return waitForSharedCredentials();
104+
}
105+
106+
// This thread will fetch the credentials
107+
auto credentials = fetchCredentialsAsync();
108108

109109
if (!credentials.IsEmpty()) {
110110
credentials.AddUserAgentFeature(Aws::Client::UserAgentFeature::CREDENTIALS_STS_WEB_IDENTITY_TOKEN);
@@ -116,3 +116,71 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
116116
void STSAssumeRoleWebIdentityCredentialsProvider::Reload() {
117117
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Calling reload on STSCredentialsProvider is a no-op and no longer in the call path");
118118
}
119+
120+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::waitForSharedCredentials() const {
121+
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Another thread is fetching credentials, waiting for result");
122+
std::unique_lock<std::mutex> lock{m_refreshMutex};
123+
m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [this]() -> bool {
124+
return !m_refreshInProgress.load();
125+
});
126+
127+
if (m_pendingCredentials) {
128+
return *m_pendingCredentials;
129+
}
130+
131+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to get shared credentials after timeout");
132+
return AWSCredentials{};
133+
}
134+
135+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::extractCredentialsFromCrt(
136+
const Aws::Crt::Auth::Credentials& crtCredentials) const {
137+
AWSCredentials credentials{};
138+
const auto accountIdCursor = crtCredentials.GetAccessKeyId();
139+
credentials.SetAWSAccessKeyId({reinterpret_cast<char*>(accountIdCursor.ptr), accountIdCursor.len});
140+
const auto secretKeyCursor = crtCredentials.GetSecretAccessKey();
141+
credentials.SetAWSSecretKey({reinterpret_cast<char*>(secretKeyCursor.ptr), secretKeyCursor.len});
142+
const auto expiration = crtCredentials.GetExpirationTimepointInSeconds();
143+
credentials.SetExpiration(DateTime{static_cast<double>(expiration)});
144+
const auto sessionTokenCursor = crtCredentials.GetSessionToken();
145+
credentials.SetSessionToken({reinterpret_cast<char*>(sessionTokenCursor.ptr), sessionTokenCursor.len});
146+
return credentials;
147+
}
148+
149+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::fetchCredentialsAsync() {
150+
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Initiating credential fetch from STS/cache");
151+
152+
AWSCredentials credentials{};
153+
std::atomic<bool> refreshDone{false};
154+
155+
m_credentialsProvider->GetCredentials(
156+
[this, &credentials, &refreshDone](
157+
std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
158+
159+
std::unique_lock<std::mutex> lock{m_refreshMutex};
160+
if (errorCode != AWS_ERROR_SUCCESS) {
161+
m_pendingCredentials.reset();
162+
} else {
163+
credentials = extractCredentialsFromCrt(*crtCredentials);
164+
165+
// Store for other waiting threads
166+
m_pendingCredentials = Aws::MakeShared<AWSCredentials>(STS_LOG_TAG, credentials);
167+
}
168+
refreshDone.store(true);
169+
m_refreshInProgress.store(false);
170+
m_refreshSignal.notify_all();
171+
});
172+
173+
// Wait for completion
174+
std::unique_lock<std::mutex> lock{m_refreshMutex};
175+
auto completed = m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool {
176+
return refreshDone.load();
177+
});
178+
179+
if (!completed) {
180+
AWS_LOGSTREAM_ERROR(STS_LOG_TAG, "Credential fetch timed out after " << m_providerFuturesTimeoutMs.count() << "ms");
181+
m_refreshInProgress.store(false);
182+
m_refreshSignal.notify_all();
183+
}
184+
185+
return credentials;
186+
}

0 commit comments

Comments
 (0)