Skip to content

Commit 47820e4

Browse files
authored
Merge pull request #5601 from alexander-demicev/stsv2
🌱 Migrate sts to sdk v2
2 parents 49f7c86 + df25a52 commit 47820e4

22 files changed

+346
-605
lines changed

controlplane/eks/controllers/awsmanagedcontrolplane_controller_test.go

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,19 @@ import (
2020
"context"
2121
"encoding/base64"
2222
"fmt"
23-
"net/http"
2423
"strconv"
2524
"testing"
2625
"time"
2726

2827
"github.com/aws/aws-sdk-go-v2/aws"
28+
signerv4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
2929
"github.com/aws/aws-sdk-go-v2/service/ec2"
3030
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
3131
"github.com/aws/aws-sdk-go-v2/service/eks"
3232
ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
3333
"github.com/aws/aws-sdk-go-v2/service/iam"
3434
iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types"
35-
stsrequest "github.com/aws/aws-sdk-go/aws/request"
36-
"github.com/aws/aws-sdk-go/service/sts"
35+
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
3736
"github.com/aws/smithy-go"
3837
"github.com/golang/mock/gomock"
3938
. "github.com/onsi/gomega"
@@ -54,8 +53,8 @@ import (
5453
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/iamauth/mock_iamauth"
5554
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services"
5655
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/network"
57-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
5856
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/securitygroup"
57+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
5958
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
6059
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
6160
"sigs.k8s.io/cluster-api/util"
@@ -76,7 +75,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) {
7675
ec2Mock *mocks.MockEC2API
7776
eksMock *mock_eksiface.MockEKSAPI
7877
iamMock *mock_iamauth.MockIAMAPI
79-
stsMock *mock_stsiface.MockSTSAPI
78+
stsMock *mock_stsiface.MockSTSClient
8079
awsNodeMock *mock_services.MockAWSNodeInterface
8180
iamAuthenticatorMock *mock_services.MockIAMAuthenticatorInterface
8281
kubeProxyMock *mock_services.MockKubeProxyInterface
@@ -96,7 +95,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) {
9695
ec2Mock = mocks.NewMockEC2API(mockCtrl)
9796
eksMock = mock_eksiface.NewMockEKSAPI(mockCtrl)
9897
iamMock = mock_iamauth.NewMockIAMAPI(mockCtrl)
99-
stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl)
98+
stsMock = mock_stsiface.NewMockSTSClient(mockCtrl)
10099

101100
// Mocking these as well, since the actual implementation requires a remote client to an actual cluster
102101
awsNodeMock = mock_services.NewMockAWSNodeInterface(mockCtrl)
@@ -854,7 +853,7 @@ func mockedEKSControlPlaneIAMRole(g *WithT, iamRec *mock_iamauth.MockIAMAPIMockR
854853
}).After(getPolicyCall).Return(&iam.AttachRolePolicyOutput{}, nil)
855854
}
856855

857-
func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSAPIMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) {
856+
func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSClientMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) {
858857
describeClusterCall := eksRec.DescribeCluster(ctx, &eks.DescribeClusterInput{
859858
Name: aws.String("test-cluster"),
860859
}).Return(nil, &ekstypes.ResourceNotFoundException{
@@ -948,12 +947,14 @@ func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockE
948947
})).Return(
949948
clusterSgDesc, nil)
950949

951-
req, err := http.NewRequest(http.MethodGet, "foobar", http.NoBody)
952-
g.Expect(err).To(BeNil())
953-
stsRec.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}).Return(&stsrequest.Request{
954-
HTTPRequest: req,
955-
Operation: &stsrequest.Operation{},
956-
}, &sts.GetCallerIdentityOutput{})
950+
stsRec.PresignGetCallerIdentity(gomock.Any(), gomock.Any(), gomock.Any()).Return(&signerv4.PresignedHTTPRequest{
951+
URL: "https://example.com",
952+
}, nil)
953+
stsRec.GetCallerIdentity(gomock.Any(), gomock.Any()).Return(&stsv2.GetCallerIdentityOutput{
954+
Account: aws.String("123456789012"),
955+
Arn: aws.String("arn:aws:iam::123456789012:user/test-user"),
956+
UserId: aws.String("AIDACKCEVSQ6C2EXAMPLE"),
957+
}, nil)
957958

958959
eksRec.TagResource(ctx, &eks.TagResourceInput{
959960
ResourceArn: clusterARN,

controlplane/rosa/controllers/rosacontrolplane_controller.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ import (
3030
"time"
3131

3232
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
33-
sts "github.com/aws/aws-sdk-go/service/sts"
34-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
3533
"github.com/google/go-cmp/cmp"
3634
idputils "github.com/openshift-online/ocm-common/pkg/idp/utils"
3735
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
@@ -62,6 +60,7 @@ import (
6260
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/annotations"
6361
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
6462
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
63+
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
6564
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
6665
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
6766
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/utils"
@@ -92,7 +91,7 @@ type ROSAControlPlaneReconciler struct {
9291
WatchFilterValue string
9392
WaitInfraPeriod time.Duration
9493
Endpoints []scope.ServiceEndpoint
95-
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI
94+
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient
9695
NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error)
9796
// Exposing the restClientConfig for integration test. No need to initialize.
9897
restClientConfig *restclient.Config
@@ -221,7 +220,11 @@ func (r *ROSAControlPlaneReconciler) reconcileNormal(ctx context.Context, rosaSc
221220
return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err)
222221
}
223222

224-
creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity))
223+
creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{
224+
Account: rosaScope.Identity.Account,
225+
Arn: rosaScope.Identity.Arn,
226+
UserId: rosaScope.Identity.UserId,
227+
})
225228
if err != nil {
226229
return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err)
227230
}
@@ -354,7 +357,11 @@ func (r *ROSAControlPlaneReconciler) reconcileDelete(ctx context.Context, rosaSc
354357
return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err)
355358
}
356359

357-
creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity))
360+
creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{
361+
Account: rosaScope.Identity.Account,
362+
Arn: rosaScope.Identity.Arn,
363+
UserId: rosaScope.Identity.UserId,
364+
})
358365
if err != nil {
359366
return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err)
360367
}
@@ -1130,12 +1137,3 @@ func buildAPIEndpoint(cluster *cmv1.Cluster) (*clusterv1.APIEndpoint, error) {
11301137
Port: int32(port), //#nosec G109 G115
11311138
}, nil
11321139
}
1133-
1134-
// TODO: Remove this and update the aws-sdk lib to v2.
1135-
func convertStsV2(identity *sts.GetCallerIdentityOutput) *stsv2.GetCallerIdentityOutput {
1136-
return &stsv2.GetCallerIdentityOutput{
1137-
Account: identity.Account,
1138-
Arn: identity.Arn,
1139-
UserId: identity.UserId,
1140-
}
1141-
}

controlplane/rosa/controllers/rosacontrolplane_controller_test.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ import (
2828
"testing"
2929
"time"
3030

31+
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
3132
"github.com/aws/aws-sdk-go/aws"
32-
sts "github.com/aws/aws-sdk-go/service/sts"
33-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
3433
"github.com/golang/mock/gomock"
3534
. "github.com/onsi/gomega"
3635
v1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
@@ -48,7 +47,8 @@ import (
4847
rosacontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/rosa/api/v1beta2"
4948
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
5049
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
51-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
50+
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
51+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
5252
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
5353
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
5454
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
@@ -292,10 +292,10 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) {
292292
mockCtrl := gomock.NewController(t)
293293
ctx := context.TODO()
294294
ocmMock := mocks.NewMockOCMClient(mockCtrl)
295-
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
295+
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
296296

297-
getCallerIdentityResult := &sts.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")}
298-
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(getCallerIdentityResult, nil).Times(1)
297+
getCallerIdentityResult := &stsv2.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")}
298+
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(getCallerIdentityResult, nil).Times(1)
299299

300300
expect := func(m *mocks.MockOCMClientMockRecorder) {
301301
m.ValidateHypershiftVersion(gomock.Any(), gomock.Any()).DoAndReturn(func(clusterId string, nodePoolID string) (bool, error) {
@@ -396,7 +396,9 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) {
396396
Endpoints: []scope.ServiceEndpoint{},
397397
Client: testEnv,
398398
restClientConfig: cfg,
399-
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
399+
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
400+
return stsMock
401+
},
400402
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
401403
return ocmMock, nil
402404
},

exp/controllers/awsmachinepool_controller_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ import (
5151
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services"
5252
s3svc "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3"
5353
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_s3iface"
54-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
54+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
5555
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/userdata"
5656
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
5757
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
@@ -71,7 +71,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) {
7171
asgSvc *mock_services.MockASGInterface
7272
reconSvc *mock_services.MockMachinePoolReconcileInterface
7373
s3Mock *mock_s3iface.MockS3API
74-
stsMock *mock_stsiface.MockSTSAPI
74+
stsMock *mock_stsiface.MockSTSClient
7575
recorder *record.FakeRecorder
7676
awsMachinePool *expinfrav1.AWSMachinePool
7777
secret *corev1.Secret
@@ -182,7 +182,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) {
182182
asgSvc = mock_services.NewMockASGInterface(mockCtrl)
183183
reconSvc = mock_services.NewMockMachinePoolReconcileInterface(mockCtrl)
184184
s3Mock = mock_s3iface.NewMockS3API(mockCtrl)
185-
stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl)
185+
stsMock = mock_stsiface.NewMockSTSClient(mockCtrl)
186186

187187
// If the test hangs for 9 minutes, increase the value here to the number of events during a reconciliation loop
188188
recorder = record.NewFakeRecorder(2)

exp/controllers/rosamachinepool_controller.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77

88
"github.com/aws/aws-sdk-go-v2/service/ec2"
99
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
10-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
1110
"github.com/blang/semver"
1211
"github.com/google/go-cmp/cmp"
1312
"github.com/google/go-cmp/cmp/cmpopts"
@@ -35,6 +34,7 @@ import (
3534
expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2"
3635
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
3736
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
37+
stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
3838
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
3939
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
4040
"sigs.k8s.io/cluster-api-provider-aws/v2/util/paused"
@@ -52,7 +52,7 @@ type ROSAMachinePoolReconciler struct {
5252
Recorder record.EventRecorder
5353
WatchFilterValue string
5454
Endpoints []scope.ServiceEndpoint
55-
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI
55+
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsservice.STSClient
5656
NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error)
5757
}
5858

exp/controllers/rosamachinepool_controller_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"testing"
77
"time"
88

9-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
109
"github.com/golang/mock/gomock"
1110
. "github.com/onsi/gomega"
1211
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
@@ -26,7 +25,8 @@ import (
2625
expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2"
2726
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
2827
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
29-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
28+
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
29+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
3030
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
3131
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
3232
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
@@ -546,15 +546,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) {
546546
ocmMock := mocks.NewMockOCMClient(mockCtrl)
547547
test.expect(ocmMock.EXPECT())
548548

549-
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
550-
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1)
549+
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
550+
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1)
551551

552552
r := ROSAMachinePoolReconciler{
553553
Recorder: recorder,
554554
WatchFilterValue: "",
555555
Endpoints: []scope.ServiceEndpoint{},
556556
Client: testEnv,
557-
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
557+
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
558+
return stsMock
559+
},
558560
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
559561
return ocmMock, nil
560562
},
@@ -641,15 +643,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) {
641643
}
642644
expect(ocmMock.EXPECT())
643645

644-
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
645-
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1)
646+
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
647+
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1)
646648

647649
r := ROSAMachinePoolReconciler{
648650
Recorder: recorder,
649651
WatchFilterValue: "",
650652
Endpoints: []scope.ServiceEndpoint{},
651653
Client: testEnv,
652-
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
654+
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
655+
return stsMock
656+
},
653657
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
654658
return ocmMock, nil
655659
},

pkg/cloud/endpointsv2/endpoints.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"github.com/aws/aws-sdk-go-v2/service/s3"
3333
"github.com/aws/aws-sdk-go-v2/service/sqs"
3434
"github.com/aws/aws-sdk-go-v2/service/ssm"
35+
"github.com/aws/aws-sdk-go-v2/service/sts"
3536
smithyendpoints "github.com/aws/smithy-go/endpoints"
3637

3738
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
@@ -303,3 +304,25 @@ func (s *SSMEndpointResolver) ResolveEndpoint(ctx context.Context, params ssm.En
303304
params.Region = &endpoint.SigningRegion
304305
return ssm.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
305306
}
307+
308+
// STSEndpointResolver implements EndpointResolverV2 interface for STS.
309+
type STSEndpointResolver struct {
310+
*MultiServiceEndpointResolver
311+
}
312+
313+
// ResolveEndpoint for STS.
314+
func (s *STSEndpointResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (smithyendpoints.Endpoint, error) {
315+
// If custom endpoint not found, return default endpoint for the service
316+
log := logger.FromContext(ctx)
317+
endpoint, ok := s.endpoints[sts.ServiceID]
318+
319+
if !ok {
320+
log.Debug("Custom endpoint not found, using default endpoint")
321+
return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
322+
}
323+
324+
log.Debug("Custom endpoint found, using custom endpoint", "endpoint", endpoint.URL)
325+
params.Endpoint = &endpoint.URL
326+
params.Region = &endpoint.SigningRegion
327+
return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
328+
}

pkg/cloud/scope/clients.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ import (
2828
"github.com/aws/aws-sdk-go-v2/service/s3"
2929
"github.com/aws/aws-sdk-go-v2/service/sqs"
3030
"github.com/aws/aws-sdk-go-v2/service/ssm"
31+
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
3132
"github.com/aws/aws-sdk-go/aws"
3233
"github.com/aws/aws-sdk-go/aws/awserr"
3334
"github.com/aws/aws-sdk-go/aws/request"
3435
"github.com/aws/aws-sdk-go/service/secretsmanager"
3536
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
36-
"github.com/aws/aws-sdk-go/service/sts"
37-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
3837
"k8s.io/apimachinery/pkg/runtime"
3938

4039
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
4140
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/endpointsv2"
4241
awslogs "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/logs"
4342
awsmetrics "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/metrics"
4443
awsmetricsv2 "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/metricsv2"
44+
stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
4545
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/throttle"
4646
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
4747
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/record"
@@ -270,13 +270,26 @@ func NewIAMClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logg
270270
}
271271

272272
// NewSTSClient creates a new STS API client for a given session.
273-
func NewSTSClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) stsiface.STSAPI {
274-
stsClient := sts.New(session.Session(), aws.NewConfig().WithLogLevel(awslogs.GetAWSLogLevel(logger.GetLogger())).WithLogger(awslogs.NewWrapLogr(logger.GetLogger())))
275-
stsClient.Handlers.Build.PushFrontNamed(getUserAgentHandler())
276-
stsClient.Handlers.CompleteAttempt.PushFront(awsmetrics.CaptureRequestMetrics(scopeUser.ControllerName()))
277-
stsClient.Handlers.Complete.PushBack(recordAWSPermissionsIssue(target))
273+
func NewSTSClient(scopeUser cloud.ScopeUsage, session cloud.Session, logger logger.Wrapper, target runtime.Object) stsservice.STSClient {
274+
cfg := session.SessionV2()
275+
multiSvcEndpointResolver := endpointsv2.NewMultiServiceEndpointResolver()
276+
stsEndpointResolver := &endpointsv2.STSEndpointResolver{
277+
MultiServiceEndpointResolver: multiSvcEndpointResolver,
278+
}
279+
280+
stsOpts := []func(*stsv2.Options){
281+
func(o *stsv2.Options) {
282+
o.Logger = logger.GetAWSLogger()
283+
o.ClientLogMode = awslogs.GetAWSLogLevelV2(logger.GetLogger())
284+
o.EndpointResolverV2 = stsEndpointResolver
285+
},
286+
stsv2.WithAPIOptions(
287+
awsmetricsv2.WithMiddlewares(scopeUser.ControllerName(), target),
288+
awsmetricsv2.WithCAPAUserAgentMiddleware(),
289+
),
290+
}
278291

279-
return stsClient
292+
return stsservice.NewClientWrapper(stsv2.NewFromConfig(cfg, stsOpts...))
280293
}
281294

282295
// NewSSMClient creates a new Secrets API client for a given session.

0 commit comments

Comments
 (0)