Skip to content

Commit 4f70bd4

Browse files
committed
new sso test
1 parent 889a4b8 commit 4f70bd4

File tree

1 file changed

+257
-0
lines changed

1 file changed

+257
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
2+
#include <aws/testing/AwsCppSdkGTestSuite.h>
3+
#include <aws/testing/AwsTestHelpers.h>
4+
#include <aws/testing/mocks/aws/client/MockAWSClient.h>
5+
#include <aws/testing/mocks/http/MockHttpClient.h>
6+
#include <aws/testing/mocks/aws/auth/MockAWSHttpResourceClient.h>
7+
#include <aws/testing/platform/PlatformTesting.h>
8+
#include <aws/core/auth/AWSCredentialsProvider.h>
9+
#include <aws/core/auth/AWSCredentialsProviderChain.h>
10+
#include <aws/core/auth/SSOCredentialsProvider.h>
11+
#include <aws/core/auth/GeneralHTTPCredentialsProvider.h>
12+
#include <aws/core/client/AWSClient.h>
13+
#include <aws/core/utils/StringUtils.h>
14+
#include <aws/core/utils/HashingUtils.h>
15+
#include <aws/core/platform/FileSystem.h>
16+
#include <aws/core/utils/FileSystemUtils.h>
17+
#include <fstream>
18+
#include <sys/stat.h>
19+
#include <thread>
20+
21+
using namespace Aws::Client;
22+
using namespace Aws::Auth;
23+
using namespace Aws::Http;
24+
using namespace Aws::FileSystem;
25+
using namespace Aws::Http::Standard;
26+
27+
namespace {
28+
const char* TEST_LOG_TAG = "CredentialTrackingTest";
29+
}
30+
31+
32+
// Custom client that uses default credential provider for testing
33+
class CredentialTestingClient : public Aws::Client::AWSClient
34+
{
35+
public:
36+
explicit CredentialTestingClient(const Aws::Client::ClientConfiguration& configuration)
37+
: AWSClient(configuration,
38+
Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(TEST_LOG_TAG,
39+
Aws::MakeShared<DefaultAWSCredentialsProviderChain>(TEST_LOG_TAG),
40+
"service", configuration.region),
41+
Aws::MakeShared<MockAWSErrorMarshaller>(TEST_LOG_TAG))
42+
{
43+
}
44+
45+
// Constructor with custom credential provider for IMDS test
46+
explicit CredentialTestingClient(const Aws::Client::ClientConfiguration& configuration,
47+
std::shared_ptr<AWSCredentialsProvider> credentialsProvider)
48+
: AWSClient(configuration,
49+
Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(TEST_LOG_TAG,
50+
credentialsProvider,
51+
"service", configuration.region),
52+
Aws::MakeShared<MockAWSErrorMarshaller>(TEST_LOG_TAG))
53+
{
54+
}
55+
56+
Aws::Client::HttpResponseOutcome MakeRequest(const Aws::AmazonWebServiceRequest& request)
57+
{
58+
auto uri = Aws::Http::URI("https://test.com");
59+
return AWSClient::AttemptExhaustively(uri, request, Aws::Http::HttpMethod::HTTP_POST, Aws::Auth::SIGV4_SIGNER);
60+
}
61+
62+
const char* GetServiceClientName() const override { return "CredentialTestingClient"; }
63+
64+
protected:
65+
Aws::Client::AWSError<Aws::Client::CoreErrors> BuildAWSError(const std::shared_ptr<Aws::Http::HttpResponse>& response) const override
66+
{
67+
AWS_UNREFERENCED_PARAM(response);
68+
return Aws::Client::AWSError<Aws::Client::CoreErrors>(Aws::Client::CoreErrors::UNKNOWN, false);
69+
}
70+
};
71+
72+
class SSOCredentialsProviderTest : public Aws::Testing::AwsCppSdkGTestSuite
73+
{
74+
protected:
75+
void SetUp() override
76+
{
77+
AwsCppSdkGTestSuite::SetUp();
78+
79+
// Create test directories
80+
Aws::String uuid = Aws::Utils::UUID::RandomUUID();
81+
m_testDir = "/tmp/aws_sso_test_" + uuid;
82+
m_configPath = m_testDir + "/config";
83+
m_ssoDir = m_testDir + "/sso/cache";
84+
85+
Aws::FileSystem::CreateDirectoryIfNotExists(m_testDir.c_str());
86+
Aws::FileSystem::CreateDirectoryIfNotExists((m_testDir + "/sso").c_str());
87+
Aws::FileSystem::CreateDirectoryIfNotExists(m_ssoDir.c_str());
88+
89+
// Save original AWS_CONFIG_FILE value
90+
m_originalConfigFile = Aws::Environment::GetEnv("AWS_CONFIG_FILE");
91+
92+
// Set AWS_CONFIG_FILE to our test config
93+
Aws::Environment::SetEnv("AWS_CONFIG_FILE", m_configPath.c_str(), 1);
94+
95+
// Set up mock HTTP client
96+
mockHttpClient = Aws::MakeShared<MockHttpClient>("SSOTest");
97+
mockHttpClientFactory = Aws::MakeShared<MockHttpClientFactory>("SSOTest");
98+
mockHttpClientFactory->SetClient(mockHttpClient);
99+
SetHttpClientFactory(mockHttpClientFactory);
100+
}
101+
102+
void TearDown() override
103+
{
104+
// Restore original AWS_CONFIG_FILE
105+
if (!m_originalConfigFile.empty())
106+
{
107+
Aws::Environment::SetEnv("AWS_CONFIG_FILE", m_originalConfigFile.c_str(), 1);
108+
}
109+
else
110+
{
111+
Aws::Environment::UnSetEnv("AWS_CONFIG_FILE");
112+
}
113+
114+
// Reset HTTP clients
115+
if (mockHttpClient) {
116+
mockHttpClient->Reset();
117+
mockHttpClient = nullptr;
118+
}
119+
if (mockHttpClientFactory) {
120+
mockHttpClientFactory = nullptr;
121+
}
122+
123+
// Cleanup test files
124+
Aws::FileSystem::RemoveFileIfExists(m_configPath.c_str());
125+
126+
//AwsCppSdkGTestSuite::TearDown();
127+
}
128+
129+
void CreateTestConfig(const Aws::String& startUrl = "https://test.awsapps.com/start")
130+
{
131+
std::ofstream configFile(m_configPath.c_str());
132+
configFile << "[default]\n"
133+
<< "sso_account_id = 123456789012\n"
134+
<< "sso_region = us-east-1\n"
135+
<< "sso_role_name = TestRole\n"
136+
<< "sso_start_url = " << startUrl << std::endl;
137+
configFile.close();
138+
}
139+
140+
void CreateSSOTokenFile(const Aws::String& startUrl)
141+
{
142+
// Use a simple hash for the test (SHA1 of the start URL)
143+
Aws::String hashedStartUrl = "d033e22ae348aeb5660fc2140aec35850c4da997"; // Simple test hash
144+
Aws::String tokenPath = m_ssoDir + "/" + hashedStartUrl + ".json";
145+
146+
// Create token file with future expiration
147+
std::ofstream tokenFile(tokenPath.c_str());
148+
auto futureTime = Aws::Utils::DateTime::Now() + std::chrono::hours(1);
149+
150+
tokenFile << "{\n"
151+
<< " \"accessToken\": \"test-token\",\n"
152+
<< " \"expiresAt\": \"" << futureTime.ToGmtString(Aws::Utils::DateFormat::ISO_8601) << "\",\n"
153+
<< " \"region\": \"us-east-1\",\n"
154+
<< " \"startUrl\": \"" << startUrl << "\"\n"
155+
<< "}" << std::endl;
156+
tokenFile.close();
157+
}
158+
159+
160+
void RunTestWithCredentialsProvider(const std::shared_ptr<AWSCredentialsProvider>& credentialsProvider, const Aws::String& id) {
161+
// Setup mock response
162+
std::shared_ptr<HttpRequest> requestTmp =
163+
CreateHttpRequest(Aws::Http::URI("dummy"), Aws::Http::HttpMethod::HTTP_POST,
164+
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
165+
auto successResponse = Aws::MakeShared<Standard::StandardHttpResponse>(TEST_LOG_TAG, requestTmp);
166+
successResponse->SetResponseCode(HttpResponseCode::OK);
167+
successResponse->GetResponseBody() << "{}";
168+
mockHttpClient->AddResponseToReturn(successResponse);
169+
170+
// Create client configuration
171+
Aws::Client::ClientConfigurationInitValues cfgInit;
172+
cfgInit.shouldDisableIMDS = true;
173+
Aws::Client::ClientConfiguration clientConfig(cfgInit);
174+
clientConfig.region = Aws::Region::US_EAST_1;
175+
176+
// Create credential testing client that uses default provider chain
177+
CredentialTestingClient client(clientConfig, credentialsProvider);
178+
179+
// Create mock request
180+
AmazonWebServiceRequestMock mockRequest;
181+
182+
// Make request
183+
auto outcome = client.MakeRequest(mockRequest);
184+
ASSERT_TRUE(outcome.IsSuccess());
185+
186+
// Verify User-Agent contains environment credentials tracking
187+
auto lastRequest = mockHttpClient->GetMostRecentHttpRequest();
188+
EXPECT_TRUE(lastRequest.HasHeader(Aws::Http::USER_AGENT_HEADER));
189+
const auto& userAgent = lastRequest.GetHeaderValue(Aws::Http::USER_AGENT_HEADER);
190+
EXPECT_FALSE(userAgent.empty());
191+
192+
const auto userAgentParsed = Aws::Utils::StringUtils::Split(userAgent, ' ');
193+
194+
// Verify there's only one m/ section (no duplicate m/ sections)
195+
int mSectionCount = 0;
196+
for (const auto& part : userAgentParsed) {
197+
if (part.find("m/") != Aws::String::npos) {
198+
mSectionCount++;
199+
}
200+
}
201+
EXPECT_EQ(1, mSectionCount);
202+
203+
// Check for environment credentials business metric (g) in user agent
204+
auto businessMetrics = std::find_if(userAgentParsed.begin(), userAgentParsed.end(),
205+
[&id](const Aws::String& value) { return value.find("m/") != Aws::String::npos && value.find(id) != Aws::String::npos; });
206+
207+
EXPECT_TRUE(businessMetrics != userAgentParsed.end());
208+
}
209+
210+
Aws::String m_testDir;
211+
Aws::String m_configPath;
212+
Aws::String m_ssoDir;
213+
Aws::String m_originalConfigFile;
214+
std::shared_ptr<MockHttpClient> mockHttpClient;
215+
std::shared_ptr<MockHttpClientFactory> mockHttpClientFactory;
216+
};
217+
218+
TEST_F(SSOCredentialsProviderTest, TestSSOCredentialsTracking)
219+
{
220+
const Aws::String startUrl = "https://test.awsapps.com/start";
221+
222+
// Create test configuration
223+
CreateTestConfig(startUrl);
224+
CreateSSOTokenFile(startUrl);
225+
226+
// Mock SSO credentials API response
227+
std::shared_ptr<Aws::Http::HttpRequest> ssoRequest = CreateHttpRequest(
228+
Aws::Http::URI("https://portal.sso.us-east-1.amazonaws.com/federation/credentials"),
229+
Aws::Http::HttpMethod::HTTP_GET,
230+
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
231+
232+
auto ssoResponse = Aws::MakeShared<Aws::Http::Standard::StandardHttpResponse>("SSOTest", ssoRequest);
233+
ssoResponse->SetResponseCode(Aws::Http::HttpResponseCode::OK);
234+
ssoResponse->GetResponseBody() << R"({
235+
"roleCredentials": {
236+
"accessKeyId": "AKIAIOSFODNN7EXAMPLE",
237+
"secretAccessKey": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
238+
"sessionToken": "AQoDYXdzEJr...",
239+
"expiration": )" << (Aws::Utils::DateTime::Now().Millis() + 3600000) << R"(
240+
}
241+
})";
242+
mockHttpClient->AddResponseToReturn(ssoResponse);
243+
244+
// Create SSO credentials provider as shared_ptr
245+
auto ssoProvider = Aws::MakeShared<SSOCredentialsProvider>(TEST_LOG_TAG);
246+
247+
// Get credentials using regular method
248+
auto credentials = ssoProvider->GetAWSCredentials();
249+
250+
// Verify credentials were retrieved
251+
EXPECT_FALSE(credentials.IsEmpty());
252+
EXPECT_EQ("AKIAIOSFODNN7EXAMPLE", credentials.GetAWSAccessKeyId());
253+
EXPECT_EQ("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", credentials.GetAWSSecretKey());
254+
255+
// Test credential tracking
256+
RunTestWithCredentialsProvider(ssoProvider, "s");
257+
}

0 commit comments

Comments
 (0)