@@ -37,11 +37,26 @@ STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentials
3737 }
3838 return UUID::RandomUUID ();
3939 }().c_str ();
40- m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity (stsConfig);
40+
41+ // Create underlying STS provider
42+ auto stsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity (stsConfig);
43+ if (!stsProvider || !stsProvider->IsValid ()) {
44+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to create underlying STS credentials provider" );
45+ return ;
46+ }
47+
48+ // Wrap with caching provider
49+ Aws::Crt::Auth::CredentialsProviderCachedConfig cachedConfig;
50+ cachedConfig.Provider = stsProvider;
51+ cachedConfig.CachedCredentialTTL = credentialsConfig.stsCredentialsProviderConfig .credentialCacheCacheTTL ;
52+
53+ m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderCached (cachedConfig);
4154 if (m_credentialsProvider && m_credentialsProvider->IsValid ()) {
4255 m_state = STATE::INITIALIZED;
56+ AWS_LOGSTREAM_INFO (STS_LOG_TAG,
57+ " STS credentials provider initialized with cache TTL " << cachedConfig.CachedCredentialTTL .count () << " ms" );
4358 } else {
44- AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to create STS credentials provider" );
59+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to create cached STS credentials provider" );
4560 }
4661}
4762
@@ -80,31 +95,15 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
8095 AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " STSCredentialsProvider is not initialized, returning empty credentials" );
8196 return AWSCredentials{};
8297 }
83- AWSCredentials credentials{};
84- auto refreshDone = false ;
85- m_credentialsProvider->GetCredentials (
86- [this , &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
87- {
88- const std::unique_lock<std::mutex> lock{m_refreshMutex};
89- if (errorCode != AWS_ERROR_SUCCESS) {
90- AWS_LOGSTREAM_ERROR (STS_LOG_TAG, " Failed to get credentials from STS: " << errorCode);
91- } else {
92- const auto accountIdCursor = crtCredentials->GetAccessKeyId ();
93- credentials.SetAWSAccessKeyId ({reinterpret_cast <char *>(accountIdCursor.ptr ), accountIdCursor.len });
94- const auto secretKeuCursor = crtCredentials->GetSecretAccessKey ();
95- credentials.SetAWSSecretKey ({reinterpret_cast <char *>(secretKeuCursor.ptr ), secretKeuCursor.len });
96- const auto expiration = crtCredentials->GetExpirationTimepointInSeconds ();
97- credentials.SetExpiration (DateTime{static_cast <double >(expiration)});
98- const auto sessionTokenCursor = crtCredentials->GetSessionToken ();
99- credentials.SetSessionToken ({reinterpret_cast <char *>(sessionTokenCursor.ptr ), sessionTokenCursor.len });
100- }
101- refreshDone = true ;
102- }
103- m_refreshSignal.notify_one ();
104- });
10598
106- std::unique_lock<std::mutex> lock{m_refreshMutex};
107- m_refreshSignal.wait_for (lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone; });
99+ // Thread-safe check: If another thread is already fetching, wait for its result
100+ auto expected = false ;
101+ if (!m_refreshInProgress.compare_exchange_strong (expected, true )) {
102+ return waitForSharedCredentials ();
103+ }
104+
105+ // This thread will fetch the credentials
106+ auto credentials = fetchCredentialsAsync ();
108107
109108 if (!credentials.IsEmpty ()) {
110109 credentials.AddUserAgentFeature (Aws::Client::UserAgentFeature::CREDENTIALS_STS_WEB_IDENTITY_TOKEN);
@@ -116,3 +115,65 @@ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
116115void STSAssumeRoleWebIdentityCredentialsProvider::Reload () {
117116 AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " Calling reload on STSCredentialsProvider is a no-op and no longer in the call path" );
118117}
118+
119+ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::waitForSharedCredentials () const {
120+ AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " Another thread is fetching credentials, waiting for result" );
121+ std::unique_lock<std::mutex> lock{m_refreshMutex};
122+ m_refreshSignal.wait_for (lock, m_providerFuturesTimeoutMs, [this ]() -> bool { return !m_refreshInProgress.load (); });
123+
124+ if (m_pendingCredentials) {
125+ return *m_pendingCredentials;
126+ }
127+
128+ AWS_LOGSTREAM_WARN (STS_LOG_TAG, " Failed to get shared credentials after timeout" );
129+ return AWSCredentials{};
130+ }
131+
132+ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::extractCredentialsFromCrt (
133+ const Aws::Crt::Auth::Credentials& crtCredentials) const {
134+ AWSCredentials credentials{};
135+ const auto accountIdCursor = crtCredentials.GetAccessKeyId ();
136+ credentials.SetAWSAccessKeyId ({reinterpret_cast <char *>(accountIdCursor.ptr ), accountIdCursor.len });
137+ const auto secretKeyCursor = crtCredentials.GetSecretAccessKey ();
138+ credentials.SetAWSSecretKey ({reinterpret_cast <char *>(secretKeyCursor.ptr ), secretKeyCursor.len });
139+ const auto expiration = crtCredentials.GetExpirationTimepointInSeconds ();
140+ credentials.SetExpiration (DateTime{static_cast <double >(expiration)});
141+ const auto sessionTokenCursor = crtCredentials.GetSessionToken ();
142+ credentials.SetSessionToken ({reinterpret_cast <char *>(sessionTokenCursor.ptr ), sessionTokenCursor.len });
143+ return credentials;
144+ }
145+
146+ AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::fetchCredentialsAsync () {
147+ AWS_LOGSTREAM_DEBUG (STS_LOG_TAG, " Initiating credential fetch from STS/cache" );
148+
149+ AWSCredentials credentials{};
150+ std::atomic<bool > refreshDone{false };
151+
152+ m_credentialsProvider->GetCredentials (
153+ [this , &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void {
154+ std::unique_lock<std::mutex> lock{m_refreshMutex};
155+ if (errorCode != AWS_ERROR_SUCCESS) {
156+ m_pendingCredentials.reset ();
157+ } else {
158+ credentials = extractCredentialsFromCrt (*crtCredentials);
159+
160+ // Store for other waiting threads
161+ m_pendingCredentials = Aws::MakeShared<AWSCredentials>(STS_LOG_TAG, credentials);
162+ }
163+ refreshDone.store (true );
164+ m_refreshInProgress.store (false );
165+ m_refreshSignal.notify_all ();
166+ });
167+
168+ // Wait for completion
169+ std::unique_lock<std::mutex> lock{m_refreshMutex};
170+ auto completed = m_refreshSignal.wait_for (lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone.load (); });
171+
172+ if (!completed) {
173+ AWS_LOGSTREAM_ERROR (STS_LOG_TAG, " Credential fetch timed out after " << m_providerFuturesTimeoutMs.count () << " ms" );
174+ m_refreshInProgress.store (false );
175+ m_refreshSignal.notify_all ();
176+ }
177+
178+ return credentials;
179+ }
0 commit comments