Skip to content

Commit 4498063

Browse files
Pass AWS client configuration to STSProfileWithWebIdentityCredentialsProvider. (#4641)
This PR applies a similar fix to #4616 but for `STSProfileWithWebIdentityCredentialsProvider`. That class was also adjusted to prevent a memory leak, to use the user-provided `STSClient` if available, and to use public APIs. The issue has been validated to be fixed. --- TYPE: BUG DESC: Fix HTTP requests for AWS assume role with web identity not honoring config options.
1 parent 8594341 commit 4498063

File tree

3 files changed

+55
-38
lines changed

3 files changed

+55
-38
lines changed

tiledb/sm/filesystem/s3.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,14 @@ Status S3::init_client() const {
14211421
}
14221422
case 16: {
14231423
credentials_provider_ = make_shared<
1424-
Aws::Auth::STSProfileWithWebIdentityCredentialsProvider>(HERE());
1424+
Aws::Auth::STSProfileWithWebIdentityCredentialsProvider>(
1425+
HERE(),
1426+
Aws::Auth::GetConfigProfileName(),
1427+
std::chrono::minutes(60),
1428+
[client_config](const auto& credentials) {
1429+
return make_shared<Aws::STS::STSClient>(
1430+
HERE(), credentials, client_config);
1431+
});
14251432
break;
14261433
}
14271434
default: {

tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.cc

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include <aws/core/utils/logging/LogMacros.h>
4949
#include <aws/sts/STSClient.h>
5050
#include <aws/sts/model/AssumeRoleRequest.h>
51+
#include <aws/sts/model/AssumeRoleWithWebIdentityRequest.h>
5152

5253
#include <utility>
5354

@@ -77,8 +78,8 @@ STSProfileWithWebIdentityCredentialsProvider::
7778
STSProfileWithWebIdentityCredentialsProvider(
7879
const Aws::String& profileName,
7980
std::chrono::minutes duration,
80-
const std::function<Aws::STS::STSClient*(const AWSCredentials&)>&
81-
stsClientFactory)
81+
const std::function<std::shared_ptr<Aws::STS::STSClient>(
82+
const AWSCredentials&)>& stsClientFactory)
8283
: m_profileName(profileName)
8384
, m_duration(duration)
8485
, m_reloadFrequency(
@@ -430,27 +431,22 @@ STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromSTS(
430431
const Aws::String& externalID) {
431432
using namespace Aws::STS::Model;
432433
if (m_stsClientFactory) {
433-
return GetCredentialsFromSTSInternal(
434-
roleArn, externalID, m_stsClientFactory(credentials));
434+
auto client = m_stsClientFactory(credentials);
435+
return GetCredentialsFromSTSInternal(roleArn, externalID, client.get());
435436
}
436437

437438
Aws::STS::STSClient stsClient{credentials};
438439
return GetCredentialsFromSTSInternal(roleArn, externalID, &stsClient);
439440
}
440441

441-
AWSCredentials
442-
STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity(
443-
const Config::Profile& profile) {
442+
AWSCredentials STSProfileWithWebIdentityCredentialsProvider::
443+
GetCredentialsFromWebIdentityInternal(
444+
const Config::Profile& profile, Aws::STS::STSClient* client) {
445+
using namespace Aws::STS::Model;
444446
const Aws::String& m_roleArn = profile.GetRoleArn();
445447
Aws::String m_tokenFile = profile.GetValue("web_identity_token_file");
446448
Aws::String m_sessionName = profile.GetValue("role_session_name");
447449

448-
auto tmpRegion = profile.GetRegion();
449-
if (tmpRegion.empty()) {
450-
// Set same default as STSAssumeRoleWebIdentityCredentialsProvider
451-
tmpRegion = Aws::Region::US_EAST_1;
452-
}
453-
454450
if (m_sessionName.empty()) {
455451
m_sessionName = Aws::Utils::UUID::RandomUUID();
456452
}
@@ -467,30 +463,40 @@ STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity(
467463
return {};
468464
}
469465

470-
Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{
471-
m_sessionName, m_roleArn, m_token};
472-
473-
Aws::Client::ClientConfiguration config;
474-
config.scheme = Aws::Http::Scheme::HTTPS;
475-
config.region = tmpRegion;
466+
AssumeRoleWithWebIdentityRequest request;
467+
request.SetRoleArn(m_roleArn);
468+
request.SetRoleSessionName(m_sessionName);
469+
request.SetWebIdentityToken(m_token);
476470

477-
Aws::Vector<Aws::String> retryableErrors;
478-
retryableErrors.push_back("IDPCommunicationError");
479-
retryableErrors.push_back("InvalidIdentityToken");
480-
481-
config.retryStrategy =
482-
Aws::MakeShared<Aws::Client::SpecifiedRetryableErrorsRetryStrategy>(
483-
CLASS_TAG, retryableErrors, 3 /*maxRetries*/);
471+
auto outcome = client->AssumeRoleWithWebIdentity(request);
472+
if (outcome.IsSuccess()) {
473+
const auto& modelCredentials = outcome.GetResult().GetCredentials();
474+
AWS_LOGSTREAM_TRACE(
475+
CLASS_TAG,
476+
"Successfully retrieved credentials with AWS_ACCESS_KEY: "
477+
<< modelCredentials.GetAccessKeyId());
478+
return {
479+
modelCredentials.GetAccessKeyId(),
480+
modelCredentials.GetSecretAccessKey(),
481+
modelCredentials.GetSessionToken(),
482+
modelCredentials.GetExpiration()};
483+
} else {
484+
AWS_LOGSTREAM_ERROR(CLASS_TAG, "failed to assume role" << m_roleArn);
485+
}
486+
return {};
487+
}
484488

485-
auto m_client =
486-
Aws::MakeUnique<Aws::Internal::STSCredentialsClient>(CLASS_TAG, config);
487-
auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request);
488-
AWS_LOGSTREAM_TRACE(
489-
CLASS_TAG,
490-
"Successfully retrieved credentials with AWS_ACCESS_KEY: "
491-
<< result.creds.GetAWSAccessKeyId());
489+
AWSCredentials
490+
STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity(
491+
const Config::Profile& profile) {
492+
using namespace Aws::STS::Model;
493+
if (m_stsClientFactory) {
494+
auto client = m_stsClientFactory({});
495+
return GetCredentialsFromWebIdentityInternal(profile, client.get());
496+
}
492497

493-
return result.creds;
498+
Aws::STS::STSClient stsClient{AWSCredentials{}};
499+
return GetCredentialsFromWebIdentityInternal(profile, &stsClient);
494500
}
495501

496502
#endif // HAVE_S3s

tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ class /* AWS_IDENTITY_MANAGEMENT_API */
9797
STSProfileWithWebIdentityCredentialsProvider(
9898
const Aws::String& profileName,
9999
std::chrono::minutes duration,
100-
const std::function<Aws::STS::STSClient*(const AWSCredentials&)>&
101-
stsClientFactory);
100+
const std::function<std::shared_ptr<Aws::STS::STSClient>(
101+
const AWSCredentials&)>& stsClientFactory);
102102

103103
/**
104104
* Fetches the credentials set from STS following the rules defined in the
@@ -132,11 +132,15 @@ class /* AWS_IDENTITY_MANAGEMENT_API */
132132
const Aws::String& externalID,
133133
Aws::STS::STSClient* client);
134134

135+
AWSCredentials GetCredentialsFromWebIdentityInternal(
136+
const Config::Profile& profile, Aws::STS::STSClient* client);
137+
135138
Aws::String m_profileName;
136139
AWSCredentials m_credentials;
137140
const std::chrono::minutes m_duration;
138141
const std::chrono::milliseconds m_reloadFrequency;
139-
std::function<Aws::STS::STSClient*(const AWSCredentials&)> m_stsClientFactory;
142+
std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)>
143+
m_stsClientFactory;
140144
};
141145
} // namespace Auth
142146
} // namespace Aws

0 commit comments

Comments
 (0)