@@ -34,12 +34,17 @@ class MockSTSClient : public STSClient
3434 Model::AssumeRoleOutcome AssumeRole (const Model::AssumeRoleRequest& request) const override
3535 {
3636 m_capturedRequest = request;
37- return m_mockedOutcome;
37+ if (!m_mockedOutcomes.empty ()) {
38+ auto outcome = m_mockedOutcomes.front ();
39+ m_mockedOutcomes.pop ();
40+ return outcome;
41+ }
42+ return STSError{};
3843 }
3944
4045 void MockAssumeRole (const Model::AssumeRoleOutcome& outcome)
4146 {
42- m_mockedOutcome = outcome;
47+ m_mockedOutcomes. push ( outcome) ;
4348 }
4449
4550 const Model::AssumeRoleRequest& CapturedRequest () const
@@ -54,7 +59,7 @@ class MockSTSClient : public STSClient
5459
5560private:
5661 mutable Model::AssumeRoleRequest m_capturedRequest;
57- Model::AssumeRoleOutcome m_mockedOutcome ;
62+ mutable Aws::Queue< Model::AssumeRoleOutcome> m_mockedOutcomes ;
5863 AWSCredentials m_credentials;
5964};
6065
@@ -621,4 +626,72 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference
621626
622627 ASSERT_TRUE (actualCredentials.IsExpiredOrEmpty ());
623628}
629+
630+ TEST_F (STSProfileCredentialsProviderTest, ShouldRefreshCredentialsNearExpiry)
631+ {
632+ Aws::OFStream configFile {m_configFilename.c_str (), Aws::OFStream::out | Aws::OFStream::trunc};
633+
634+ configFile << std::endl;
635+ configFile << " [default]" << std::endl;
636+ configFile << " source_profile = default" << std::endl;
637+ configFile << " role_arn = " << ROLE_ARN_1 << std::endl;
638+ configFile << " aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
639+ configFile << " aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
640+ configFile.close ();
641+ Aws::Config::ReloadCachedConfigFile ();
642+
643+ constexpr auto roleSessionDuration = std::chrono::seconds (5 );
644+ const DateTime expiryTime{DateTime::Now () + roleSessionDuration};
645+
646+ Model::Credentials stsCredentials;
647+ stsCredentials.WithAccessKeyId (ACCESS_KEY_ID_2)
648+ .WithSecretAccessKey (SECRET_ACCESS_KEY_ID_2)
649+ .WithSessionToken (SESSION_TOKEN)
650+ .WithExpiration (expiryTime);
651+
652+ Model::Credentials refreshedStsCredentials;
653+ refreshedStsCredentials.WithAccessKeyId (ACCESS_KEY_ID_3)
654+ .WithSecretAccessKey (SECRET_ACCESS_KEY_ID_3)
655+ .WithSessionToken (SESSION_TOKEN)
656+ .WithExpiration (expiryTime);
657+
658+ Model::AssumeRoleResult mockResult;
659+ mockResult.SetCredentials (stsCredentials);
660+ Model::AssumeRoleResult refreshedMockResult;
661+ refreshedMockResult.SetCredentials (refreshedStsCredentials);
662+ Aws::UniquePtr<MockSTSClient> stsClient;
663+ std::once_flag stsClientInitialized;
664+
665+ int stsCallCounter = 0 ;
666+ STSProfileCredentialsProvider credsProvider (" default" , std::chrono::minutes (60 ), [&](const AWSCredentials& creds)
667+ {
668+ ++stsCallCounter;
669+ std::call_once (stsClientInitialized, [&] {
670+ stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
671+ stsClient->MockAssumeRole (mockResult);
672+ stsClient->MockAssumeRole (refreshedMockResult);
673+ });
674+ return stsClient.get ();
675+ });
676+
677+ auto actualCredentials = credsProvider.GetAWSCredentials ();
678+
679+ ASSERT_STREQ (ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId ().c_str ());
680+ ASSERT_STREQ (SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey ().c_str ());
681+ ASSERT_STREQ (SESSION_TOKEN, actualCredentials.GetSessionToken ().c_str ());
682+ ASSERT_EQ (expiryTime, actualCredentials.GetExpiration ());
683+
684+ ASSERT_EQ (1 , stsCallCounter);
685+ ASSERT_TRUE (stsClient);
686+ ASSERT_STREQ (ACCESS_KEY_ID_1, stsClient->Credentials ().GetAWSAccessKeyId ().c_str ());
687+ ASSERT_STREQ (SECRET_ACCESS_KEY_ID_1, stsClient->Credentials ().GetAWSSecretKey ().c_str ());
688+
689+ actualCredentials = credsProvider.GetAWSCredentials ();
690+ ASSERT_STREQ (ACCESS_KEY_ID_3, actualCredentials.GetAWSAccessKeyId ().c_str ());
691+ ASSERT_STREQ (SECRET_ACCESS_KEY_ID_3, actualCredentials.GetAWSSecretKey ().c_str ());
692+ ASSERT_STREQ (SESSION_TOKEN, actualCredentials.GetSessionToken ().c_str ());
693+ ASSERT_EQ (expiryTime, actualCredentials.GetExpiration ());
694+ // should have called refresh
695+ ASSERT_EQ (2 , stsCallCounter);
696+ }
624697} // namespace
0 commit comments