Skip to content

Commit a617ca1

Browse files
committed
create-aws-client-with-region
1 parent e7c9629 commit a617ca1

File tree

4 files changed

+14
-9
lines changed

4 files changed

+14
-9
lines changed

pkg/cloud/identity/identity.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,11 @@ func GetAssumeRoleCredentials(roleIdentityProvider *AWSRolePrincipalTypeProvider
8080
}
8181

8282
// NewAWSRolePrincipalTypeProvider will create a new AWSRolePrincipalTypeProvider from an AWSClusterRoleIdentity.
83-
func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, log logger.Wrapper) *AWSRolePrincipalTypeProvider {
83+
func NewAWSRolePrincipalTypeProvider(identity *infrav1.AWSClusterRoleIdentity, sourceProvider AWSPrincipalTypeProvider, region string, log logger.Wrapper) *AWSRolePrincipalTypeProvider {
8484
return &AWSRolePrincipalTypeProvider{
8585
credentials: nil,
8686
stsClient: nil,
87+
region: region,
8788
Principal: identity,
8889
sourceProvider: sourceProvider,
8990
log: log.WithName("AWSRolePrincipalTypeProvider"),
@@ -130,6 +131,7 @@ func (p *AWSStaticPrincipalTypeProvider) IsExpired() bool {
130131
type AWSRolePrincipalTypeProvider struct {
131132
Principal *infrav1.AWSClusterRoleIdentity
132133
credentials *credentials.Credentials
134+
region string
133135
sourceProvider AWSPrincipalTypeProvider
134136
log logger.Wrapper
135137
stsClient stsiface.STSAPI
@@ -154,7 +156,7 @@ func (p *AWSRolePrincipalTypeProvider) Name() string {
154156
// Retrieve returns the credential values for the AWSRolePrincipalTypeProvider.
155157
func (p *AWSRolePrincipalTypeProvider) Retrieve() (credentials.Value, error) {
156158
if p.credentials == nil || p.IsExpired() {
157-
awsConfig := aws.NewConfig()
159+
awsConfig := aws.NewConfig().WithRegion(p.region)
158160
if p.sourceProvider != nil {
159161
sourceCreds, err := p.sourceProvider.Retrieve()
160162
if err != nil {

pkg/cloud/identity/identity_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ func TestAWSStaticPrincipalTypeProvider(t *testing.T) {
6161
roleProvider := &AWSRolePrincipalTypeProvider{
6262
credentials: nil,
6363
Principal: roleIdentity,
64+
region: "us-west-2",
6465
sourceProvider: staticProvider,
6566
stsClient: stsMock,
6667
}
@@ -78,6 +79,7 @@ func TestAWSStaticPrincipalTypeProvider(t *testing.T) {
7879
roleProvider2 := &AWSRolePrincipalTypeProvider{
7980
credentials: nil,
8081
Principal: roleIdentity2,
82+
region: "us-west-2",
8183
sourceProvider: roleProvider,
8284
stsClient: stsMock,
8385
}

pkg/cloud/scope/session.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func sessionForClusterWithRegion(k8sClient client.Client, clusterScoper cloud.Se
120120
return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
121121
}
122122

123-
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScoper, log)
123+
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScoper, region, log)
124124
if err != nil {
125125
// could not get providers and retrieve the credentials
126126
conditions.MarkFalse(clusterScoper.InfraCluster(), infrav1.PrincipalCredentialRetrievedCondition, infrav1.PrincipalCredentialRetrievalFailedReason, clusterv1.ConditionSeverityError, err.Error())
@@ -256,6 +256,7 @@ func buildProvidersForRef(
256256
k8sClient client.Client,
257257
clusterScoper cloud.SessionMetadata,
258258
ref *infrav1.AWSIdentityReference,
259+
region string,
259260
log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) {
260261
if ref == nil {
261262
log.Trace("AWSCluster does not have a IdentityRef specified")
@@ -299,7 +300,7 @@ func buildProvidersForRef(
299300
setPrincipalUsageAllowedCondition(clusterScoper)
300301

301302
if roleIdentity.Spec.SourceIdentityRef != nil {
302-
providers, err = buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, roleIdentity.Spec.SourceIdentityRef, log)
303+
providers, err = buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, roleIdentity.Spec.SourceIdentityRef, region, log)
303304
if err != nil {
304305
return providers, err
305306
}
@@ -313,7 +314,7 @@ func buildProvidersForRef(
313314
}
314315
}
315316

316-
provider = identity.NewAWSRolePrincipalTypeProvider(roleIdentity, sourceProvider, log)
317+
provider = identity.NewAWSRolePrincipalTypeProvider(roleIdentity, sourceProvider, region, log)
317318
providers = append(providers, provider)
318319
default:
319320
return providers, errors.Errorf("No such provider known: '%s'", ref.Kind)
@@ -404,9 +405,9 @@ func buildAWSClusterControllerIdentity(ctx context.Context, identityObjectKey cl
404405
return nil
405406
}
406407

407-
func getProvidersForCluster(ctx context.Context, k8sClient client.Client, clusterScoper cloud.SessionMetadata, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) {
408+
func getProvidersForCluster(ctx context.Context, k8sClient client.Client, clusterScoper cloud.SessionMetadata, region string, log logger.Wrapper) ([]identity.AWSPrincipalTypeProvider, error) {
408409
providers := make([]identity.AWSPrincipalTypeProvider, 0)
409-
providers, err := buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, clusterScoper.IdentityRef(), log)
410+
providers, err := buildProvidersForRef(ctx, providers, k8sClient, clusterScoper, clusterScoper.IdentityRef(), region, log)
410411
if err != nil {
411412
return nil, err
412413
}

pkg/cloud/scope/session_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ func TestPrincipalParsing(t *testing.T) {
228228
Namespace: "default",
229229
},
230230
},
231-
AWSCluster: &infrav1.AWSCluster{},
231+
AWSCluster: &infrav1.AWSCluster{Spec: infrav1.AWSClusterSpec{Region: "us-west-2"}},
232232
},
233233
)
234234

@@ -489,7 +489,7 @@ func TestPrincipalParsing(t *testing.T) {
489489
k8sClient := fake.NewClientBuilder().WithScheme(scheme).Build()
490490
tc.setup(t, k8sClient)
491491
clusterScope.AWSCluster = &tc.awsCluster
492-
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScope, logger.NewLogger(klog.Background()))
492+
providers, err := getProvidersForCluster(context.Background(), k8sClient, clusterScope, clusterScope.Region(), logger.NewLogger(klog.Background()))
493493
if tc.expectError {
494494
if err == nil {
495495
t.Fatal("Expected an error but didn't get one")

0 commit comments

Comments
 (0)