|
2 | 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
3 | 3 | * SPDX-License-Identifier: Apache-2.0. |
4 | 4 | */ |
5 | | -#include <aws/core/Globals.h> |
| 5 | + |
| 6 | + |
6 | 7 | #include <aws/core/auth/STSCredentialsProvider.h> |
7 | | -#include <aws/core/client/ClientConfiguration.h> |
| 8 | +#include <aws/core/config/AWSProfileConfigLoader.h> |
8 | 9 | #include <aws/core/platform/Environment.h> |
9 | | -#include <aws/crt/auth/Credentials.h> |
| 10 | +#include <aws/core/platform/FileSystem.h> |
| 11 | +#include <aws/core/utils/logging/LogMacros.h> |
| 12 | +#include <aws/core/utils/StringUtils.h> |
| 13 | +#include <aws/core/utils/FileSystemUtils.h> |
| 14 | +#include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h> |
| 15 | +#include <aws/core/utils/StringUtils.h> |
| 16 | +#include <aws/core/utils/UUID.h> |
| 17 | +#include <cstdlib> |
| 18 | +#include <fstream> |
| 19 | +#include <string.h> |
| 20 | +#include <climits> |
| 21 | + |
10 | 22 |
|
11 | | -using namespace Aws::Auth; |
12 | 23 | using namespace Aws::Utils; |
| 24 | +using namespace Aws::Utils::Logging; |
| 25 | +using namespace Aws::Auth; |
| 26 | +using namespace Aws::Internal; |
| 27 | +using namespace Aws::FileSystem; |
| 28 | +using namespace Aws::Client; |
| 29 | +using Aws::Utils::Threading::ReaderLockGuard; |
| 30 | +using Aws::Utils::Threading::WriterLockGuard; |
13 | 31 |
|
14 | | -namespace { |
15 | | -const char* STS_LOG_TAG = "STSAssumeRoleWebIdentityCredentialsProvider"; |
16 | | -} |
| 32 | +static const char STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG[] = "STSAssumeRoleWithWebIdentityCredentialsProvider"; |
| 33 | +static const int STS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD = 5 * 60 * 1000; // 5 Minutes. |
17 | 34 |
|
18 | | -STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider( |
19 | | - Aws::Client::ClientConfiguration::CredentialProviderConfiguration credentialsConfig) |
20 | | - : m_credentialsProvider(nullptr), m_providerFuturesTimeoutMs(credentialsConfig.stsCredentialsProviderConfig.retrieveCredentialsFutureTimeout) |
| 35 | +STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider(Aws::Client::ClientConfiguration::CredentialProviderConfiguration credentialsConfig): |
| 36 | + m_initialized(false) |
21 | 37 | { |
22 | | - Aws::Crt::Auth::CredentialsProviderSTSWebIdentityConfig stsConfig{}; |
23 | | - stsConfig.Bootstrap = GetDefaultClientBootstrap(); |
24 | | - Aws::Crt::Io::TlsContextOptions tlsCtxOptions = Aws::Crt::Io::TlsContextOptions::InitDefaultClient(); |
25 | | - const Aws::Crt::Io::TlsContext tlsContext(tlsCtxOptions, Aws::Crt::Io::TlsMode::CLIENT); |
26 | | - const auto tlsOptions = Aws::GetDefaultTlsConnectionOptions(); |
27 | | - if (tlsOptions) { |
28 | | - stsConfig.TlsConnectionOptions = *tlsOptions; |
29 | | - } |
30 | | - stsConfig.Region = credentialsConfig.region.c_str(); |
31 | | - stsConfig.TokenFilePath = credentialsConfig.stsCredentialsProviderConfig.tokenFilePath.c_str(); |
32 | | - stsConfig.RoleArn = credentialsConfig.stsCredentialsProviderConfig.roleArn.c_str(); |
33 | | - stsConfig.SessionName = [&credentialsConfig]() -> Aws::String { |
34 | | - if (!credentialsConfig.stsCredentialsProviderConfig.sessionName.empty()) { |
35 | | - return credentialsConfig.stsCredentialsProviderConfig.sessionName; |
| 38 | + m_roleArn = Aws::Environment::GetEnv("AWS_ROLE_ARN"); |
| 39 | + m_tokenFile = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE"); |
| 40 | + m_sessionName = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME"); |
| 41 | + |
| 42 | + // check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable |
| 43 | + // region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file. |
| 44 | + if (m_roleArn.empty() || m_tokenFile.empty()) |
| 45 | + { |
| 46 | + auto profile = Aws::Config::GetCachedConfigProfile(credentialsConfig.profile); |
| 47 | + // If either of these two were not found from environment, use whatever found for all three in config file |
| 48 | + if (m_roleArn.empty() || m_tokenFile.empty()) |
| 49 | + { |
| 50 | + m_roleArn = profile.GetRoleArn(); |
| 51 | + m_tokenFile = profile.GetValue("web_identity_token_file"); |
| 52 | + m_sessionName = profile.GetValue("role_session_name"); |
| 53 | + } |
36 | 54 | } |
37 | | - return UUID::RandomUUID(); |
38 | | - }().c_str(); |
39 | | - m_credentialsProvider = Aws::Crt::Auth::CredentialsProvider::CreateCredentialsProviderSTSWebIdentity(stsConfig); |
40 | | - if (m_credentialsProvider && m_credentialsProvider->IsValid()) { |
41 | | - m_state = STATE::INITIALIZED; |
42 | | - } else { |
43 | | - AWS_LOGSTREAM_WARN(STS_LOG_TAG, "Failed to create STS credentials provider"); |
44 | | - } |
| 55 | + |
| 56 | + if (m_tokenFile.empty()) |
| 57 | + { |
| 58 | + AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Token file must be specified to use STS AssumeRole web identity creds provider."); |
| 59 | + return; // No need to do further constructing |
| 60 | + } |
| 61 | + else |
| 62 | + { |
| 63 | + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved token_file from profile_config or environment variable to be " << m_tokenFile); |
| 64 | + } |
| 65 | + |
| 66 | + if (m_roleArn.empty()) |
| 67 | + { |
| 68 | + AWS_LOGSTREAM_WARN(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "RoleArn must be specified to use STS AssumeRole web identity creds provider."); |
| 69 | + return; // No need to do further constructing |
| 70 | + } |
| 71 | + else |
| 72 | + { |
| 73 | + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved role_arn from profile_config or environment variable to be " << m_roleArn); |
| 74 | + } |
| 75 | + |
| 76 | + if (m_sessionName.empty()) |
| 77 | + { |
| 78 | + m_sessionName = Aws::Utils::UUID::PseudoRandomUUID(); |
| 79 | + } |
| 80 | + else |
| 81 | + { |
| 82 | + AWS_LOGSTREAM_DEBUG(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Resolved session_name from profile_config or environment variable to be " << m_sessionName); |
| 83 | + } |
| 84 | + |
| 85 | + Aws::Client::ClientConfiguration config; |
| 86 | + config.scheme = Aws::Http::Scheme::HTTPS; |
| 87 | + config.region = credentialsConfig.region; |
| 88 | + Aws::Vector<Aws::String> retryableErrors; |
| 89 | + retryableErrors.push_back("IDPCommunicationError"); |
| 90 | + retryableErrors.push_back("InvalidIdentityToken"); |
| 91 | + |
| 92 | + config.retryStrategy = Aws::MakeShared<SpecifiedRetryableErrorsRetryStrategy>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, retryableErrors, 3/*maxRetries*/); |
| 93 | + |
| 94 | + m_client = Aws::MakeUnique<Aws::Internal::STSCredentialsClient>(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, config); |
| 95 | + m_initialized = true; |
| 96 | + AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Creating STS AssumeRole with web identity creds provider."); |
45 | 97 | } |
46 | 98 |
|
47 | | -Aws::String GetLegacySettingFromEnvOrProfile(const Aws::String& envVar, |
48 | | - std::function<Aws::String (Aws::Config::Profile)> profileFetchFunction) |
49 | | -{ |
50 | | - auto value = Aws::Environment::GetEnv(envVar.c_str()); |
51 | | - if (value.empty()) { |
| 99 | +Aws::String LegacyGetRegion() { |
| 100 | + auto region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION"); |
| 101 | + if (region.empty()) { |
52 | 102 | auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName()); |
53 | | - value = profileFetchFunction(profile); |
| 103 | + region = profile.GetRegion(); |
54 | 104 | } |
55 | | - return value; |
| 105 | + return region; |
56 | 106 | } |
57 | 107 |
|
58 | 108 | STSAssumeRoleWebIdentityCredentialsProvider::STSAssumeRoleWebIdentityCredentialsProvider() |
59 | 109 | : STSAssumeRoleWebIdentityCredentialsProvider( |
60 | | - Aws::Client::ClientConfiguration::CredentialProviderConfiguration{ |
61 | | - Aws::Auth::GetConfigProfileName(), |
62 | | - GetLegacySettingFromEnvOrProfile("AWS_DEFAULT_REGION", |
63 | | - [](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetRegion(); }), |
64 | | - {}, |
65 | | - { |
66 | | - GetLegacySettingFromEnvOrProfile("AWS_ROLE_ARN", |
67 | | - [](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetRoleArn(); }), |
68 | | - GetLegacySettingFromEnvOrProfile("AWS_ROLE_SESSION_NAME", |
69 | | - [](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetValue("role_session_name"); }), |
70 | | - GetLegacySettingFromEnvOrProfile("AWS_WEB_IDENTITY_TOKEN_FILE", |
71 | | - [](const Aws::Config::Profile& profile) -> Aws::String { return profile.GetValue("web_identity_token_file"); }) |
72 | | - }}) |
73 | | -{} |
74 | | - |
75 | | -STSAssumeRoleWebIdentityCredentialsProvider::~STSAssumeRoleWebIdentityCredentialsProvider() = default; |
76 | | - |
77 | | -AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials() { |
78 | | - if (m_state != STATE::INITIALIZED) { |
79 | | - AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "STSCredentialsProvider is not initialized, returning empty credentials"); |
80 | | - return AWSCredentials{}; |
81 | | - } |
82 | | - AWSCredentials credentials{}; |
83 | | - auto refreshDone = false; |
84 | | - m_credentialsProvider->GetCredentials( |
85 | | - [this, &credentials, &refreshDone](std::shared_ptr<Aws::Crt::Auth::Credentials> crtCredentials, int errorCode) -> void { |
86 | | - { |
87 | | - const std::unique_lock<std::mutex> lock{m_refreshMutex}; |
88 | | - if (errorCode != AWS_ERROR_SUCCESS) { |
89 | | - AWS_LOGSTREAM_ERROR(STS_LOG_TAG, "Failed to get credentials from STS: " << errorCode); |
90 | | - } else { |
91 | | - const auto accountIdCursor = crtCredentials->GetAccessKeyId(); |
92 | | - credentials.SetAWSAccessKeyId({reinterpret_cast<char*>(accountIdCursor.ptr), accountIdCursor.len}); |
93 | | - const auto secretKeuCursor = crtCredentials->GetSecretAccessKey(); |
94 | | - credentials.SetAWSSecretKey({reinterpret_cast<char*>(secretKeuCursor.ptr), secretKeuCursor.len}); |
95 | | - const auto expiration = crtCredentials->GetExpirationTimepointInSeconds(); |
96 | | - credentials.SetExpiration(DateTime{static_cast<double>(expiration)}); |
97 | | - const auto sessionTokenCursor = crtCredentials->GetSessionToken(); |
98 | | - credentials.SetSessionToken({reinterpret_cast<char*>(sessionTokenCursor.ptr), sessionTokenCursor.len}); |
99 | | - } |
100 | | - refreshDone = true; |
101 | | - } |
102 | | - m_refreshSignal.notify_one(); |
103 | | - }); |
| 110 | + Aws::Client::ClientConfiguration::CredentialProviderConfiguration{Aws::Auth::GetConfigProfileName(), LegacyGetRegion(), {}}) {} |
| 111 | + |
| 112 | +AWSCredentials STSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials() |
| 113 | +{ |
| 114 | + // A valid client means required information like role arn and token file were constructed correctly. |
| 115 | + // We can use this provider to load creds, otherwise, we can just return empty creds. |
| 116 | + if (!m_initialized) |
| 117 | + { |
| 118 | + return Aws::Auth::AWSCredentials(); |
| 119 | + } |
| 120 | + RefreshIfExpired(); |
| 121 | + ReaderLockGuard guard(m_reloadLock); |
| 122 | + return m_credentials; |
| 123 | +} |
| 124 | + |
| 125 | +void STSAssumeRoleWebIdentityCredentialsProvider::Reload() |
| 126 | +{ |
| 127 | + AWS_LOGSTREAM_INFO(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Credentials have expired, attempting to renew from STS."); |
104 | 128 |
|
105 | | - std::unique_lock<std::mutex> lock{m_refreshMutex}; |
106 | | - m_refreshSignal.wait_for(lock, m_providerFuturesTimeoutMs, [&refreshDone]() -> bool { return refreshDone; }); |
107 | | - return credentials; |
| 129 | + Aws::IFStream tokenFile(m_tokenFile.c_str()); |
| 130 | + if(tokenFile) |
| 131 | + { |
| 132 | + Aws::String token((std::istreambuf_iterator<char>(tokenFile)), std::istreambuf_iterator<char>()); |
| 133 | + m_token = token; |
| 134 | + } |
| 135 | + else |
| 136 | + { |
| 137 | + AWS_LOGSTREAM_ERROR(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Can't open token file: " << m_tokenFile); |
| 138 | + return; |
| 139 | + } |
| 140 | + STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request {m_sessionName, m_roleArn, m_token}; |
| 141 | + |
| 142 | + auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request); |
| 143 | + AWS_LOGSTREAM_TRACE(STS_ASSUME_ROLE_WEB_IDENTITY_LOG_TAG, "Successfully retrieved credentials with AWS_ACCESS_KEY: " << result.creds.GetAWSAccessKeyId()); |
| 144 | + m_credentials = result.creds; |
| 145 | +} |
| 146 | + |
| 147 | +bool STSAssumeRoleWebIdentityCredentialsProvider::ExpiresSoon() const |
| 148 | +{ |
| 149 | + return ((m_credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < STS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD); |
108 | 150 | } |
109 | 151 |
|
110 | | -void STSAssumeRoleWebIdentityCredentialsProvider::Reload() { |
111 | | - AWS_LOGSTREAM_DEBUG(STS_LOG_TAG, "Calling reload on STSCredentialsProvider is a no-op and no longer in the call path"); |
| 152 | +void STSAssumeRoleWebIdentityCredentialsProvider::RefreshIfExpired() |
| 153 | +{ |
| 154 | + ReaderLockGuard guard(m_reloadLock); |
| 155 | + if (!m_credentials.IsEmpty() && !ExpiresSoon()) |
| 156 | + { |
| 157 | + return; |
| 158 | + } |
| 159 | + |
| 160 | + guard.UpgradeToWriterLock(); |
| 161 | + if (!m_credentials.IsExpiredOrEmpty() && !ExpiresSoon()) // double-checked lock to avoid refreshing twice |
| 162 | + { |
| 163 | + return; |
| 164 | + } |
| 165 | + |
| 166 | + Reload(); |
112 | 167 | } |
0 commit comments