Skip to content

Commit 12c046e

Browse files
committed
fix cache and threading concerns with STS Webidentity provider
1 parent 9000aa2 commit 12c046e

File tree

6 files changed

+111
-30
lines changed

6 files changed

+111
-30
lines changed

.cspell.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
"DENABLED",
132132
"DENFORCE",
133133
"APPSTORE",
134+
"UCRT",
134135
// Compiler and linker
135136
"Wpedantic",
136137
"Wextra",

.github/workflows/cspell.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ jobs:
1717
- name: cspell
1818
run: |
1919
cd aws-sdk-cpp
20-
sudo cspell --fail-fast "src/**/*.txt" "*.txt" "src/aws-cpp-sdk-core/**/*.h" "src/aws-cpp-sdk-core/**/*.cpp"
20+
sudo cspell "src/**/*.txt" "*.txt" "src/aws-cpp-sdk-core/**/*.h" "src/aws-cpp-sdk-core/**/*.cpp"
2121
if [ $? -ne 0 ]; then sudo cspell "src/**/*.txt" "*.txt" "src/aws-cpp-sdk-core/**/*.h" "src/aws-cpp-sdk-core/**/*.cpp"; exit 1; fi;

src/aws-cpp-sdk-core/include/aws/core/auth/STSCredentialsProvider.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99
#include <aws/core/Core_EXPORTS.h>
1010
#include <aws/core/auth/AWSCredentialsProvider.h>
1111

12+
#include <atomic>
13+
#include <memory>
14+
1215
namespace Aws {
1316
namespace Crt {
1417
namespace Auth {
1518
class ICredentialsProvider;
19+
class Credentials;
1620
}
1721
}
1822
}
@@ -46,10 +50,19 @@ namespace Aws
4650
INITIALIZED,
4751
SHUT_DOWN,
4852
} m_state{STATE::SHUT_DOWN};
49-
std::mutex m_refreshMutex;
50-
std::condition_variable m_refreshSignal;
53+
mutable std::mutex m_refreshMutex;
54+
mutable std::condition_variable m_refreshSignal;
5155
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> m_credentialsProvider;
5256
std::chrono::milliseconds m_providerFuturesTimeoutMs;
57+
58+
// Thread-safe credential fetch coordination
59+
mutable std::atomic<bool> m_refreshInProgress{false};
60+
mutable std::shared_ptr<AWSCredentials> m_pendingCredentials;
61+
62+
// Helper methods for credential retrieval
63+
AWSCredentials waitForSharedCredentials() const;
64+
AWSCredentials extractCredentialsFromCrt(const Aws::Crt::Auth::Credentials& crtCredentials) const;
65+
AWSCredentials fetchCredentialsAsync();
5366
};
5467
} // namespace Auth
5568
} // namespace Aws

src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,11 @@ namespace Aws
544544
* Time out for the credentials future call.
545545
*/
546546
std::chrono::milliseconds retrieveCredentialsFutureTimeout = std::chrono::seconds(10);
547+
548+
/**
549+
* How long a cached credential set will be used for
550+
*/
551+
std::chrono::milliseconds credentialCacheCacheTTL = std::chrono::minutes(50);
547552
} stsCredentialsProviderConfig;
548553
} credentialProviderConfig;
549554
};

src/aws-cpp-sdk-core/source/auth/STSCredentialsProvider.cpp

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,26 @@ STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentials
3737
}
3838
return UUID::RandomUUID();
3939
}().c_str();
40-
m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity(stsConfig);
40+
41+
// Create underlying STS provider
42+
auto stsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity(stsConfig);
43+
if (!stsProvider || !stsProvider->IsValid()) {
44+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create underlying STS credentials provider");
45+
return;
46+
}
47+
48+
// Wrap with caching provider
49+
Aws::Crt::Auth::CredentialsProviderCachedConfig cachedConfig;
50+
cachedConfig.Provider = stsProvider;
51+
cachedConfig.CachedCredentialTTL = credentialsConfig.stsCredentialsProviderConfig.credentialCacheCacheTTL;
52+
53+
m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderCached(cachedConfig);
4154
if (m_credentialsProvider && m_credentialsProvider->IsValid()) {
4255
m_state = STATE::INITIALIZED;
56+
AWS_LOGSTREAM_INFO(STS_LOG_TAG,
57+
"STS credentials provider initialized with cache TTL " << cachedConfig.CachedCredentialTTL.count() << " ms");
4358
} else {
44-
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create STS credentials provider");
59+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create cached STS credentials provider");
4560
}
4661
}
4762

@@ -80,31 +95,15 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
8095
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "STSCredentialsProvider is not initialized, returning empty credentials");
8196
return AWSCredentials{};
8297
}
83-
AWSCredentials credentials{};
84-
auto refreshDone = false;
85-
m_credentialsProvider->GetCredentials(
86-
[this, &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
87-
{
88-
const std::unique_lock<std::mutex> lock{m_refreshMutex};
89-
if (errorCode != AWS_ERROR_SUCCESS) {
90-
AWS_LOGSTREAM_ERROR(STS_LOG_TAG, "Failed to get credentials from STS: " << errorCode);
91-
} else {
92-
const auto accountIdCursor = crtCredentials->GetAccessKeyId();
93-
credentials.SetAWSAccessKeyId({reinterpret_cast<char*>(accountIdCursor.ptr), accountIdCursor.len});
94-
const auto secretKeuCursor = crtCredentials->GetSecretAccessKey();
95-
credentials.SetAWSSecretKey({reinterpret_cast<char*>(secretKeuCursor.ptr), secretKeuCursor.len});
96-
const auto expiration = crtCredentials->GetExpirationTimepointInSeconds();
97-
credentials.SetExpiration(DateTime{static_cast<double>(expiration)});
98-
const auto sessionTokenCursor = crtCredentials->GetSessionToken();
99-
credentials.SetSessionToken({reinterpret_cast<char*>(sessionTokenCursor.ptr), sessionTokenCursor.len});
100-
}
101-
refreshDone = true;
102-
}
103-
m_refreshSignal.notify_one();
104-
});
10598

106-
std::unique_lock<std::mutex> lock{m_refreshMutex};
107-
m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone; });
99+
// Thread-safe check: If another thread is already fetching, wait for its result
100+
auto expected = false;
101+
if (!m_refreshInProgress.compare_exchange_strong(expected, true)) {
102+
return waitForSharedCredentials();
103+
}
104+
105+
// This thread will fetch the credentials
106+
auto credentials = fetchCredentialsAsync();
108107

109108
if (!credentials.IsEmpty()) {
110109
credentials.AddUserAgentFeature(Aws::Client::UserAgentFeature::CREDENTIALS_STS_WEB_IDENTITY_TOKEN);
@@ -116,3 +115,65 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
116115
void STSAssumeRoleWebIdentityCredentialsProvider::Reload() {
117116
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Calling reload on STSCredentialsProvider is a no-op and no longer in the call path");
118117
}
118+
119+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::waitForSharedCredentials() const {
120+
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Another thread is fetching credentials, waiting for result");
121+
std::unique_lock<std::mutex> lock{m_refreshMutex};
122+
m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [this]() -> bool { return !m_refreshInProgress.load(); });
123+
124+
if (m_pendingCredentials) {
125+
return *m_pendingCredentials;
126+
}
127+
128+
AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to get shared credentials after timeout");
129+
return AWSCredentials{};
130+
}
131+
132+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::extractCredentialsFromCrt(
133+
const Aws::Crt::Auth::Credentials& crtCredentials) const {
134+
AWSCredentials credentials{};
135+
const auto accountIdCursor = crtCredentials.GetAccessKeyId();
136+
credentials.SetAWSAccessKeyId({reinterpret_cast<char*>(accountIdCursor.ptr), accountIdCursor.len});
137+
const auto secretKeyCursor = crtCredentials.GetSecretAccessKey();
138+
credentials.SetAWSSecretKey({reinterpret_cast<char*>(secretKeyCursor.ptr), secretKeyCursor.len});
139+
const auto expiration = crtCredentials.GetExpirationTimepointInSeconds();
140+
credentials.SetExpiration(DateTime{static_cast<double>(expiration)});
141+
const auto sessionTokenCursor = crtCredentials.GetSessionToken();
142+
credentials.SetSessionToken({reinterpret_cast<char*>(sessionTokenCursor.ptr), sessionTokenCursor.len});
143+
return credentials;
144+
}
145+
146+
AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::fetchCredentialsAsync() {
147+
AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Initiating credential fetch from STS/cache");
148+
149+
AWSCredentials credentials{};
150+
std::atomic<bool> refreshDone{false};
151+
152+
m_credentialsProvider->GetCredentials(
153+
[this, &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
154+
std::unique_lock<std::mutex> lock{m_refreshMutex};
155+
if (errorCode != AWS_ERROR_SUCCESS) {
156+
m_pendingCredentials.reset();
157+
} else {
158+
credentials = extractCredentialsFromCrt(*crtCredentials);
159+
160+
// Store for other waiting threads
161+
m_pendingCredentials = Aws::MakeShared<AWSCredentials>(STS_LOG_TAG, credentials);
162+
}
163+
refreshDone.store(true);
164+
m_refreshInProgress.store(false);
165+
m_refreshSignal.notify_all();
166+
});
167+
168+
// Wait for completion
169+
std::unique_lock<std::mutex> lock{m_refreshMutex};
170+
auto completed = m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone.load(); });
171+
172+
if (!completed) {
173+
AWS_LOGSTREAM_ERROR(STS_LOG_TAG, "Credential fetch timed out after " << m_providerFuturesTimeoutMs.count() << "ms");
174+
m_refreshInProgress.store(false);
175+
m_refreshSignal.notify_all();
176+
}
177+
178+
return credentials;
179+
}

src/aws-cpp-sdk-core/source/client/AWSErrorMarshaller.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ void JsonErrorMarshallerQueryCompatible::MarshallError(AWSError<CoreErrors>& err
301301
}
302302

303303
AWSError<CoreErrors> RpcV2ErrorMarshaller::Marshall(const Aws::Http::HttpResponse& httpResponse) const {
304-
return AWSError<CoreErrors>(CoreErrors::UNKNOWN, "Not implemented yet", "RpcV2ErrorMarshaller::Marshall not implemeneted yet: " + httpResponse.GetClientErrorMessage(), false);
304+
return AWSError<CoreErrors>(CoreErrors::UNKNOWN, "Not implemented yet",
305+
"RpcV2ErrorMarshaller::Marshall not implemented yet: " + httpResponse.GetClientErrorMessage(), false);
305306
}
306307

307308
AWSError<CoreErrors> RpcV2ErrorMarshaller::BuildAWSError(const std::shared_ptr<Http::HttpResponse>& httpResponse) const {

0 commit comments

Comments
 (0)