@@ -37,11 +37,27 @@ STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentials
3737 }
3838 return UUID::RandomUUID ();
3939 }().c_str ();
40- m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity (stsConfig);
40+
41+ // Create underlying STS provider
42+ auto stsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity (stsConfig);
43+ if (!stsProvider || !stsProvider->IsValid ()) {
44+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to create underlying STS credentials provider" );
45+ return ;
46+ }
47+
48+ // Wrap with caching provider
49+ Aws::Crt::Auth::CredentialsProviderCachedConfig cachedConfig;
50+ cachedConfig.Provider = stsProvider;
51+ cachedConfig.CachedCredentialTTL = credentialsConfig.stsCredentialsProviderConfig .credentialCacheCacheTTL ;
52+
53+ m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderCached (cachedConfig);
4154 if (m_credentialsProvider && m_credentialsProvider->IsValid ()) {
4255 m_state = STATE::INITIALIZED;
56+ AWS_LOGSTREAM_INFO (STS_LOG_TAG, " STS credentials provider initialized with cache TTL "
57+ << cachedConfig.CachedCredentialTTL .count ()
58+ << " ms" );
4359 } else {
44- AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to create STS credentials provider" );
60+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to create cached STS credentials provider" );
4561 }
4662}
4763
@@ -77,34 +93,18 @@ STSAssumeRoleWebIdentityCredentialsProvider::~STSAssumeRoleWebIdentityCredential
7793
7894AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials () {
7995 if (m_state != STATE::INITIALIZED) {
80- AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " STSCredentialsProvider is not initialized, returning empty credentials" );
96+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " STSCredentialsProvider is not initialized, returning empty credentials" );
8197 return AWSCredentials{};
8298 }
83- AWSCredentials credentials{};
84- auto refreshDone = false ;
85- m_credentialsProvider->GetCredentials (
86- [this , &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
87- {
88- const std::unique_lock<std::mutex> lock{m_refreshMutex};
89- if (errorCode != AWS_ERROR_SUCCESS) {
90- AWS_LOGSTREAM_ERROR (STS_LOG_TAG, " Failed to get credentials from STS: " << errorCode);
91- } else {
92- const auto accountIdCursor = crtCredentials->GetAccessKeyId ();
93- credentials.SetAWSAccessKeyId ({reinterpret_cast <char *>(accountIdCursor.ptr ), accountIdCursor.len });
94- const auto secretKeuCursor = crtCredentials->GetSecretAccessKey ();
95- credentials.SetAWSSecretKey ({reinterpret_cast <char *>(secretKeuCursor.ptr ), secretKeuCursor.len });
96- const auto expiration = crtCredentials->GetExpirationTimepointInSeconds ();
97- credentials.SetExpiration (DateTime{static_cast <double >(expiration)});
98- const auto sessionTokenCursor = crtCredentials->GetSessionToken ();
99- credentials.SetSessionToken ({reinterpret_cast <char *>(sessionTokenCursor.ptr ), sessionTokenCursor.len });
100- }
101- refreshDone = true ;
102- }
103- m_refreshSignal.notify_one ();
104- });
10599
106- std::unique_lock<std::mutex> lock{m_refreshMutex};
107- m_refreshSignal.wait_for (lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone; });
100+ // Thread-safe check: If another thread is already fetching, wait for its result
101+ auto expected = false ;
102+ if (!m_refreshInProgress.compare_exchange_strong (expected, true )) {
103+ return waitForSharedCredentials ();
104+ }
105+
106+ // This thread will fetch the credentials
107+ auto credentials = fetchCredentialsAsync ();
108108
109109 if (!credentials.IsEmpty ()) {
110110 credentials.AddUserAgentFeature (Aws::Client::UserAgentFeature::CREDENTIALS_STS_WEB_IDENTITY_TOKEN);
@@ -116,3 +116,71 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
116116void STSAssumeRoleWebIdentityCredentialsProvider::Reload () {
117117 AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " Calling reload on STSCredentialsProvider is a no-op and no longer in the call path" );
118118}
119+
120+ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::waitForSharedCredentials () const {
121+ AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " Another thread is fetching credentials, waiting for result" );
122+ std::unique_lock<std::mutex> lock{m_refreshMutex};
123+ m_refreshSignal.wait_for (lock, m_providerFuturesTimeoutMs, [this ]() -> bool {
124+ return !m_refreshInProgress.load ();
125+ });
126+
127+ if (m_pendingCredentials) {
128+ return *m_pendingCredentials;
129+ }
130+
131+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to get shared credentials after timeout" );
132+ return AWSCredentials{};
133+ }
134+
135+ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::extractCredentialsFromCrt (
136+ const Aws::Crt::Auth::Credentials& crtCredentials) const {
137+ AWSCredentials credentials{};
138+ const auto accountIdCursor = crtCredentials.GetAccessKeyId ();
139+ credentials.SetAWSAccessKeyId ({reinterpret_cast <char *>(accountIdCursor.ptr ), accountIdCursor.len });
140+ const auto secretKeyCursor = crtCredentials.GetSecretAccessKey ();
141+ credentials.SetAWSSecretKey ({reinterpret_cast <char *>(secretKeyCursor.ptr ), secretKeyCursor.len });
142+ const auto expiration = crtCredentials.GetExpirationTimepointInSeconds ();
143+ credentials.SetExpiration (DateTime{static_cast <double >(expiration)});
144+ const auto sessionTokenCursor = crtCredentials.GetSessionToken ();
145+ credentials.SetSessionToken ({reinterpret_cast <char *>(sessionTokenCursor.ptr ), sessionTokenCursor.len });
146+ return credentials;
147+ }
148+
149+ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::fetchCredentialsAsync () {
150+ AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " Initiating credential fetch from STS/cache" );
151+
152+ AWSCredentials credentials{};
153+ std::atomic<bool > refreshDone{false };
154+
155+ m_credentialsProvider->GetCredentials (
156+ [this , &credentials, &refreshDone](
157+ std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
158+
159+ std::unique_lock<std::mutex> lock{m_refreshMutex};
160+ if (errorCode != AWS_ERROR_SUCCESS) {
161+ m_pendingCredentials.reset ();
162+ } else {
163+ credentials = extractCredentialsFromCrt (*crtCredentials);
164+
165+ // Store for other waiting threads
166+ m_pendingCredentials = std::make_shared<AWSCredentials>(credentials);
167+ }
168+ refreshDone.store (true );
169+ m_refreshInProgress.store (false );
170+ m_refreshSignal.notify_all ();
171+ });
172+
173+ // Wait for completion
174+ std::unique_lock<std::mutex> lock{m_refreshMutex};
175+ auto completed = m_refreshSignal.wait_for (lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool {
176+ return refreshDone.load ();
177+ });
178+
179+ if (!completed) {
180+ AWS_LOGSTREAM_ERROR (STS_LOG_TAG, " Credential fetch timed out after " << m_providerFuturesTimeoutMs.count () << " ms" );
181+ m_refreshInProgress.store (false );
182+ m_refreshSignal.notify_all ();
183+ }
184+
185+ return credentials;
186+ }
0 commit comments