Skip to content

Commit 8a53a20

Browse files
Migrate sts to sdk v2
Signed-off-by: Alexandr Demicev <[email protected]>
1 parent 483f3a9 commit 8a53a20

File tree

22 files changed

+328
-611
lines changed

22 files changed

+328
-611
lines changed

controlplane/eks/controllers/awsmanagedcontrolplane_controller_test.go

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121
"encoding/base64"
2222
"fmt"
23-
"net/http"
2423
"strconv"
2524
"testing"
2625
"time"
@@ -32,8 +31,7 @@ import (
3231
ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types"
3332
"github.com/aws/aws-sdk-go-v2/service/iam"
3433
iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types"
35-
stsrequest "github.com/aws/aws-sdk-go/aws/request"
36-
"github.com/aws/aws-sdk-go/service/sts"
34+
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
3735
"github.com/aws/smithy-go"
3836
"github.com/golang/mock/gomock"
3937
. "github.com/onsi/gomega"
@@ -54,8 +52,8 @@ import (
5452
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/iamauth/mock_iamauth"
5553
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services"
5654
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/network"
57-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
5855
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/securitygroup"
56+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
5957
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
6058
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
6159
"sigs.k8s.io/cluster-api/util"
@@ -76,7 +74,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) {
7674
ec2Mock *mocks.MockEC2API
7775
eksMock *mock_eksiface.MockEKSAPI
7876
iamMock *mock_iamauth.MockIAMAPI
79-
stsMock *mock_stsiface.MockSTSAPI
77+
stsMock *mock_stsiface.MockSTSClient
8078
awsNodeMock *mock_services.MockAWSNodeInterface
8179
iamAuthenticatorMock *mock_services.MockIAMAuthenticatorInterface
8280
kubeProxyMock *mock_services.MockKubeProxyInterface
@@ -96,7 +94,7 @@ func TestAWSManagedControlPlaneReconcilerIntegrationTests(t *testing.T) {
9694
ec2Mock = mocks.NewMockEC2API(mockCtrl)
9795
eksMock = mock_eksiface.NewMockEKSAPI(mockCtrl)
9896
iamMock = mock_iamauth.NewMockIAMAPI(mockCtrl)
99-
stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl)
97+
stsMock = mock_stsiface.NewMockSTSClient(mockCtrl)
10098

10199
// Mocking these as well, since the actual implementation requires a remote client to an actual cluster
102100
awsNodeMock = mock_services.NewMockAWSNodeInterface(mockCtrl)
@@ -854,7 +852,7 @@ func mockedEKSControlPlaneIAMRole(g *WithT, iamRec *mock_iamauth.MockIAMAPIMockR
854852
}).After(getPolicyCall).Return(&iam.AttachRolePolicyOutput{}, nil)
855853
}
856854

857-
func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSAPIMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) {
855+
func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockEKSAPIMockRecorder, iamRec *mock_iamauth.MockIAMAPIMockRecorder, ec2Rec *mocks.MockEC2APIMockRecorder, stsRec *mock_stsiface.MockSTSClientMockRecorder, awsNodeRec *mock_services.MockAWSNodeInterfaceMockRecorder, kubeProxyRec *mock_services.MockKubeProxyInterfaceMockRecorder, iamAuthenticatorRec *mock_services.MockIAMAuthenticatorInterfaceMockRecorder) {
858856
describeClusterCall := eksRec.DescribeCluster(ctx, &eks.DescribeClusterInput{
859857
Name: aws.String("test-cluster"),
860858
}).Return(nil, &ekstypes.ResourceNotFoundException{
@@ -948,12 +946,11 @@ func mockedEKSCluster(ctx context.Context, g *WithT, eksRec *mock_eksiface.MockE
948946
})).Return(
949947
clusterSgDesc, nil)
950948

951-
req, err := http.NewRequest(http.MethodGet, "foobar", http.NoBody)
952-
g.Expect(err).To(BeNil())
953-
stsRec.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}).Return(&stsrequest.Request{
954-
HTTPRequest: req,
955-
Operation: &stsrequest.Operation{},
956-
}, &sts.GetCallerIdentityOutput{})
949+
stsRec.GetCallerIdentity(gomock.Any(), gomock.Any()).Return(&stsv2.GetCallerIdentityOutput{
950+
Account: aws.String("123456789012"),
951+
Arn: aws.String("arn:aws:iam::123456789012:user/test-user"),
952+
UserId: aws.String("AIDACKCEVSQ6C2EXAMPLE"),
953+
}, nil)
957954

958955
eksRec.TagResource(ctx, &eks.TagResourceInput{
959956
ResourceArn: clusterARN,

controlplane/rosa/controllers/rosacontrolplane_controller.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ import (
3030
"time"
3131

3232
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
33-
sts "github.com/aws/aws-sdk-go/service/sts"
34-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
3533
"github.com/google/go-cmp/cmp"
3634
idputils "github.com/openshift-online/ocm-common/pkg/idp/utils"
3735
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
@@ -50,6 +48,7 @@ import (
5048
"k8s.io/client-go/tools/clientcmd/api"
5149
"k8s.io/klog/v2"
5250
"k8s.io/utils/ptr"
51+
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
5352
ctrl "sigs.k8s.io/controller-runtime"
5453
"sigs.k8s.io/controller-runtime/pkg/client"
5554
"sigs.k8s.io/controller-runtime/pkg/controller"
@@ -92,7 +91,7 @@ type ROSAControlPlaneReconciler struct {
9291
WatchFilterValue string
9392
WaitInfraPeriod time.Duration
9493
Endpoints []scope.ServiceEndpoint
95-
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI
94+
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient
9695
NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error)
9796
// Exposing the restClientConfig for integration test. No need to initialize.
9897
restClientConfig *restclient.Config
@@ -221,7 +220,11 @@ func (r *ROSAControlPlaneReconciler) reconcileNormal(ctx context.Context, rosaSc
221220
return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err)
222221
}
223222

224-
creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity))
223+
creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{
224+
Account: rosaScope.Identity.Account,
225+
Arn: rosaScope.Identity.Arn,
226+
UserId: rosaScope.Identity.UserId,
227+
})
225228
if err != nil {
226229
return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err)
227230
}
@@ -354,7 +357,11 @@ func (r *ROSAControlPlaneReconciler) reconcileDelete(ctx context.Context, rosaSc
354357
return ctrl.Result{}, fmt.Errorf("failed to create OCM client: %w", err)
355358
}
356359

357-
creator, err := rosaaws.CreatorForCallerIdentity(convertStsV2(rosaScope.Identity))
360+
creator, err := rosaaws.CreatorForCallerIdentity(&stsv2.GetCallerIdentityOutput{
361+
Account: rosaScope.Identity.Account,
362+
Arn: rosaScope.Identity.Arn,
363+
UserId: rosaScope.Identity.UserId,
364+
})
358365
if err != nil {
359366
return ctrl.Result{}, fmt.Errorf("failed to transform caller identity to creator: %w", err)
360367
}
@@ -1130,12 +1137,3 @@ func buildAPIEndpoint(cluster *cmv1.Cluster) (*clusterv1.APIEndpoint, error) {
11301137
Port: int32(port), //#nosec G109 G115
11311138
}, nil
11321139
}
1133-
1134-
// TODO: Remove this and update the aws-sdk lib to v2.
1135-
func convertStsV2(identity *sts.GetCallerIdentityOutput) *stsv2.GetCallerIdentityOutput {
1136-
return &stsv2.GetCallerIdentityOutput{
1137-
Account: identity.Account,
1138-
Arn: identity.Arn,
1139-
UserId: identity.UserId,
1140-
}
1141-
}

controlplane/rosa/controllers/rosacontrolplane_controller_test.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ import (
2828
"testing"
2929
"time"
3030

31+
stsv2 "github.com/aws/aws-sdk-go-v2/service/sts"
3132
"github.com/aws/aws-sdk-go/aws"
32-
sts "github.com/aws/aws-sdk-go/service/sts"
33-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
3433
"github.com/golang/mock/gomock"
3534
. "github.com/onsi/gomega"
3635
v1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
@@ -41,14 +40,15 @@ import (
4140
"k8s.io/apimachinery/pkg/runtime"
4241
"k8s.io/apimachinery/pkg/types"
4342
restclient "k8s.io/client-go/rest"
43+
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
4444
ctrl "sigs.k8s.io/controller-runtime"
4545
"sigs.k8s.io/controller-runtime/pkg/client"
4646

4747
infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2"
4848
rosacontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/rosa/api/v1beta2"
4949
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
5050
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
51-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
51+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
5252
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
5353
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
5454
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
@@ -292,10 +292,10 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) {
292292
mockCtrl := gomock.NewController(t)
293293
ctx := context.TODO()
294294
ocmMock := mocks.NewMockOCMClient(mockCtrl)
295-
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
295+
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
296296

297-
getCallerIdentityResult := &sts.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")}
298-
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Return(getCallerIdentityResult, nil).Times(1)
297+
getCallerIdentityResult := &stsv2.GetCallerIdentityOutput{Account: aws.String("foo"), Arn: aws.String("arn:aws:iam::123456789012:rosa/foo")}
298+
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Return(getCallerIdentityResult, nil).Times(1)
299299

300300
expect := func(m *mocks.MockOCMClientMockRecorder) {
301301
m.ValidateHypershiftVersion(gomock.Any(), gomock.Any()).DoAndReturn(func(clusterId string, nodePoolID string) (bool, error) {
@@ -396,7 +396,9 @@ func TestRosaControlPlaneReconcileStatusVersion(t *testing.T) {
396396
Endpoints: []scope.ServiceEndpoint{},
397397
Client: testEnv,
398398
restClientConfig: cfg,
399-
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
399+
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
400+
return stsMock
401+
},
400402
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
401403
return ocmMock, nil
402404
},

exp/controllers/awsmachinepool_controller_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ import (
5151
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/mock_services"
5252
s3svc "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3"
5353
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_s3iface"
54-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
54+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
5555
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/userdata"
5656
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
5757
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
@@ -71,7 +71,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) {
7171
asgSvc *mock_services.MockASGInterface
7272
reconSvc *mock_services.MockMachinePoolReconcileInterface
7373
s3Mock *mock_s3iface.MockS3API
74-
stsMock *mock_stsiface.MockSTSAPI
74+
stsMock *mock_stsiface.MockSTSClient
7575
recorder *record.FakeRecorder
7676
awsMachinePool *expinfrav1.AWSMachinePool
7777
secret *corev1.Secret
@@ -182,7 +182,7 @@ func TestAWSMachinePoolReconciler(t *testing.T) {
182182
asgSvc = mock_services.NewMockASGInterface(mockCtrl)
183183
reconSvc = mock_services.NewMockMachinePoolReconcileInterface(mockCtrl)
184184
s3Mock = mock_s3iface.NewMockS3API(mockCtrl)
185-
stsMock = mock_stsiface.NewMockSTSAPI(mockCtrl)
185+
stsMock = mock_stsiface.NewMockSTSClient(mockCtrl)
186186

187187
// If the test hangs for 9 minutes, increase the value here to the number of events during a reconciliation loop
188188
recorder = record.NewFakeRecorder(2)

exp/controllers/rosamachinepool_controller.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77

88
"github.com/aws/aws-sdk-go-v2/service/ec2"
99
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
10-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
1110
"github.com/blang/semver"
1211
"github.com/google/go-cmp/cmp"
1312
"github.com/google/go-cmp/cmp/cmpopts"
@@ -23,18 +22,11 @@ import (
2322
"k8s.io/client-go/tools/record"
2423
"k8s.io/klog/v2"
2524
"k8s.io/utils/ptr"
26-
ctrl "sigs.k8s.io/controller-runtime"
27-
"sigs.k8s.io/controller-runtime/pkg/client"
28-
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
29-
"sigs.k8s.io/controller-runtime/pkg/controller"
30-
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
31-
"sigs.k8s.io/controller-runtime/pkg/handler"
32-
"sigs.k8s.io/controller-runtime/pkg/reconcile"
33-
3425
rosacontrolplanev1 "sigs.k8s.io/cluster-api-provider-aws/v2/controlplane/rosa/api/v1beta2"
3526
expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2"
3627
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
3728
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
29+
stsservice "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
3830
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
3931
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
4032
"sigs.k8s.io/cluster-api-provider-aws/v2/util/paused"
@@ -44,6 +36,13 @@ import (
4436
"sigs.k8s.io/cluster-api/util/annotations"
4537
"sigs.k8s.io/cluster-api/util/conditions"
4638
"sigs.k8s.io/cluster-api/util/predicates"
39+
ctrl "sigs.k8s.io/controller-runtime"
40+
"sigs.k8s.io/controller-runtime/pkg/client"
41+
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
42+
"sigs.k8s.io/controller-runtime/pkg/controller"
43+
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
44+
"sigs.k8s.io/controller-runtime/pkg/handler"
45+
"sigs.k8s.io/controller-runtime/pkg/reconcile"
4746
)
4847

4948
// ROSAMachinePoolReconciler reconciles a ROSAMachinePool object.
@@ -52,7 +51,7 @@ type ROSAMachinePoolReconciler struct {
5251
Recorder record.EventRecorder
5352
WatchFilterValue string
5453
Endpoints []scope.ServiceEndpoint
55-
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI
54+
NewStsClient func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsservice.STSClient
5655
NewOCMClient func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error)
5756
}
5857

exp/controllers/rosamachinepool_controller_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"testing"
77
"time"
88

9-
"github.com/aws/aws-sdk-go/service/sts/stsiface"
109
"github.com/golang/mock/gomock"
1110
. "github.com/onsi/gomega"
1211
cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
@@ -17,6 +16,7 @@ import (
1716
"k8s.io/apimachinery/pkg/util/intstr"
1817
"k8s.io/client-go/tools/record"
1918
"k8s.io/utils/ptr"
19+
stsiface "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts"
2020
ctrl "sigs.k8s.io/controller-runtime"
2121
"sigs.k8s.io/controller-runtime/pkg/client"
2222
"sigs.k8s.io/controller-runtime/pkg/reconcile"
@@ -26,7 +26,7 @@ import (
2626
expinfrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/exp/api/v1beta2"
2727
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud"
2828
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope"
29-
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/s3/mock_stsiface"
29+
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/services/sts/mock_stsiface"
3030
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
3131
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/rosa"
3232
"sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks"
@@ -546,15 +546,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) {
546546
ocmMock := mocks.NewMockOCMClient(mockCtrl)
547547
test.expect(ocmMock.EXPECT())
548548

549-
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
550-
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1)
549+
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
550+
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1)
551551

552552
r := ROSAMachinePoolReconciler{
553553
Recorder: recorder,
554554
WatchFilterValue: "",
555555
Endpoints: []scope.ServiceEndpoint{},
556556
Client: testEnv,
557-
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
557+
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
558+
return stsMock
559+
},
558560
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
559561
return ocmMock, nil
560562
},
@@ -641,15 +643,17 @@ func TestRosaMachinePoolReconcile(t *testing.T) {
641643
}
642644
expect(ocmMock.EXPECT())
643645

644-
stsMock := mock_stsiface.NewMockSTSAPI(mockCtrl)
645-
stsMock.EXPECT().GetCallerIdentity(gomock.Any()).Times(1)
646+
stsMock := mock_stsiface.NewMockSTSClient(mockCtrl)
647+
stsMock.EXPECT().GetCallerIdentity(gomock.Any(), gomock.Any()).Times(1)
646648

647649
r := ROSAMachinePoolReconciler{
648650
Recorder: recorder,
649651
WatchFilterValue: "",
650652
Endpoints: []scope.ServiceEndpoint{},
651653
Client: testEnv,
652-
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSAPI { return stsMock },
654+
NewStsClient: func(cloud.ScopeUsage, cloud.Session, logger.Wrapper, runtime.Object) stsiface.STSClient {
655+
return stsMock
656+
},
653657
NewOCMClient: func(ctx context.Context, rosaScope *scope.ROSAControlPlaneScope) (rosa.OCMClient, error) {
654658
return ocmMock, nil
655659
},

pkg/cloud/endpointsv2/endpoints.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import (
3232
"github.com/aws/aws-sdk-go-v2/service/s3"
3333
"github.com/aws/aws-sdk-go-v2/service/sqs"
3434
"github.com/aws/aws-sdk-go-v2/service/ssm"
35+
"github.com/aws/aws-sdk-go-v2/service/sts"
3536
smithyendpoints "github.com/aws/smithy-go/endpoints"
3637

3738
"sigs.k8s.io/cluster-api-provider-aws/v2/pkg/logger"
@@ -303,3 +304,25 @@ func (s *SSMEndpointResolver) ResolveEndpoint(ctx context.Context, params ssm.En
303304
params.Region = &endpoint.SigningRegion
304305
return ssm.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
305306
}
307+
308+
// STSEndpointResolver implements EndpointResolverV2 interface for STS.
309+
type STSEndpointResolver struct {
310+
*MultiServiceEndpointResolver
311+
}
312+
313+
// ResolveEndpoint for STS.
314+
func (s *STSEndpointResolver) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (smithyendpoints.Endpoint, error) {
315+
// If custom endpoint not found, return default endpoint for the service
316+
log := logger.FromContext(ctx)
317+
endpoint, ok := s.endpoints[sts.ServiceID]
318+
319+
if !ok {
320+
log.Debug("Custom endpoint not found, using default endpoint")
321+
return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
322+
}
323+
324+
log.Debug("Custom endpoint found, using custom endpoint", "endpoint", endpoint.URL)
325+
params.Endpoint = &endpoint.URL
326+
params.Region = &endpoint.SigningRegion
327+
return sts.NewDefaultEndpointResolverV2().ResolveEndpoint(ctx, params)
328+
}

0 commit comments

Comments
 (0)