Skip to content

Commit a5691d0

Browse files
oarbusimaastha
andauthored
fix: Backport of "Fixes STS region resolution when using cross-region authentication (#3718)" (#3731)
* fix: Fixes STS region resolution when using cross-region authentication (#3718) * update provider to infer region from sts_endpoint * changelog * fmt * nit * nit * fmt * unit tests * test * nit * nit * nit * improve clarity of comment * use asserts in unit test * use const instead of magic number --------- Co-authored-by: Oriol Arbusi Abadal <[email protected]> * rename changelog file --------- Co-authored-by: maastha <[email protected]>
1 parent fd49f90 commit a5691d0

File tree

4 files changed

+174
-27
lines changed

4 files changed

+174
-27
lines changed

.changelog/3731.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
```release-note:bug
2+
provider: Fixes STS region resolution when using cross-region authentication
3+
```

.github/workflows/acceptance-tests-runner.yml

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,23 @@ jobs:
457457
needs: [ change-detection, get-provider-version ]
458458
if: ${{ needs.change-detection.outputs.assume_role == 'true' || inputs.test_group == 'assume_role' }}
459459
runs-on: ubuntu-latest
460+
strategy:
461+
fail-fast: false
462+
matrix:
463+
include:
464+
# Secret and STS Endpoint in same region
465+
- name: same-region-us-east-1
466+
aws_region: US_EAST_1
467+
sts_endpoint: https://sts.us-east-1.amazonaws.com/
468+
# Secret and STS Endpoint in different regions(Cross-region)
469+
- name: cross-sts-us-east-1-secret-eu-north-1
470+
aws_region: EU_NORTH_1
471+
sts_endpoint: https://sts.us-east-1.amazonaws.com/
472+
# Global STS endpoint (signs as us-east-1), secrets in eu-west-1
473+
- name: global-sts-secret-eu-west-1
474+
aws_region: EU_WEST_1
475+
sts_endpoint: https://sts.amazonaws.com
476+
name: assume_role – ${{ matrix.name }}
460477
permissions: {}
461478
steps:
462479
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8
@@ -478,19 +495,20 @@ jobs:
478495
AWS_ACCESS_KEY_ID: ${{ secrets.aws_access_key_id }}
479496
ASSUME_ROLE_ARN: ${{ vars.ASSUME_ROLE_ARN }}
480497
run: bash ./scripts/generate-credentials-with-sts-assume-role.sh
481-
- name: Acceptance Tests
498+
- name: Acceptance Tests (matrix)
482499
env:
483500
MONGODB_ATLAS_PUBLIC_KEY: ""
484501
MONGODB_ATLAS_PRIVATE_KEY: ""
485502
ASSUME_ROLE_ARN: ${{ vars.ASSUME_ROLE_ARN }}
486-
AWS_REGION: ${{ vars.AWS_REGION }}
487-
STS_ENDPOINT: ${{ vars.STS_ENDPOINT }}
503+
AWS_REGION: ${{ matrix.aws_region }}
504+
STS_ENDPOINT: ${{ matrix.sts_endpoint }}
488505
SECRET_NAME: ${{ inputs.aws_secret_name }}
489506
AWS_ACCESS_KEY_ID: ${{ steps.sts-assume-role.outputs.aws_access_key_id }}
490507
AWS_SECRET_ACCESS_KEY: ${{ steps.sts-assume-role.outputs.aws_secret_access_key }}
491508
AWS_SESSION_TOKEN: ${{ steps.sts-assume-role.outputs.AWS_SESSION_TOKEN }}
492509
MONGODB_ATLAS_LAST_VERSION: ${{ needs.get-provider-version.outputs.provider_version }}
493510
ACCTEST_PACKAGES: ./internal/provider
511+
ACCTEST_REGEX_RUN: ^TestAccSTSAssumeRole_basic$
494512
run: make testacc
495513

496514
autogen:

internal/provider/credentials.go

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"encoding/json"
55
"fmt"
66
"log"
7+
"net/url"
8+
"strings"
79

810
"github.com/aws/aws-sdk-go/aws"
911
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -12,48 +14,39 @@ import (
1214
"github.com/aws/aws-sdk-go/aws/endpoints"
1315
"github.com/aws/aws-sdk-go/aws/session"
1416
"github.com/aws/aws-sdk-go/service/secretsmanager"
17+
"github.com/aws/aws-sdk-go/service/sts"
18+
1519
"github.com/mongodb/terraform-provider-mongodbatlas/internal/config"
1620
)
1721

1822
const (
19-
endPointSTSDefault = "https://sts.amazonaws.com"
23+
endPointSTSHostnameDefault = "sts.amazonaws.com"
24+
DefaultRegionSTS = "us-east-1"
25+
minSegmentsForSTSRegionalHost = 4
2026
)
2127

2228
func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID, awsSecretAccessKey, awsSessionToken, endpoint string) (config.Config, error) {
23-
ep, err := endpoints.GetSTSRegionalEndpoint("regional")
24-
if err != nil {
25-
log.Printf("GetSTSRegionalEndpoint error: %s", err)
26-
return *cfg, err
27-
}
28-
2929
defaultResolver := endpoints.DefaultResolver()
30-
stsCustResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
31-
if service == endpoints.StsServiceID {
32-
if endpoint == "" {
33-
return endpoints.ResolvedEndpoint{
34-
URL: endPointSTSDefault,
35-
SigningRegion: region,
36-
}, nil
30+
stsCustResolverFn := func(service, _ string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
31+
if service == sts.EndpointsID {
32+
resolved, err := ResolveSTSEndpoint(endpoint, region)
33+
if err != nil {
34+
return endpoints.ResolvedEndpoint{}, err
3735
}
38-
return endpoints.ResolvedEndpoint{
39-
URL: endpoint,
40-
SigningRegion: region,
41-
}, nil
36+
return resolved, nil
4237
}
43-
4438
return defaultResolver.EndpointFor(service, region, optFns...)
4539
}
4640

4741
sess := session.Must(session.NewSession(&aws.Config{
48-
Region: aws.String(region),
49-
Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken),
50-
STSRegionalEndpoint: ep,
51-
EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn),
42+
Region: aws.String(region),
43+
Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken),
44+
EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn),
5245
}))
5346

5447
creds := stscreds.NewCredentials(sess, cfg.AssumeRole.RoleARN)
5548

56-
_, err = sess.Config.Credentials.Get()
49+
_, err := sess.Config.Credentials.Get()
5750
if err != nil {
5851
log.Printf("Session get credentials error: %s", err)
5952
return *cfg, err
@@ -87,6 +80,45 @@ func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID,
8780
return *cfg, nil
8881
}
8982

83+
func DeriveSTSRegionFromEndpoint(ep string) string {
84+
if ep == "" {
85+
return ""
86+
}
87+
u, err := url.Parse(ep)
88+
if err != nil {
89+
return DefaultRegionSTS
90+
}
91+
host := u.Hostname() // valid values: sts.us-west-2.amazonaws.com or sts.amazonaws.com
92+
93+
if host == endPointSTSHostnameDefault {
94+
return DefaultRegionSTS
95+
}
96+
97+
parts := strings.Split(host, ".")
98+
if len(parts) >= minSegmentsForSTSRegionalHost && parts[0] == "sts" {
99+
return parts[1]
100+
}
101+
return DefaultRegionSTS
102+
}
103+
104+
func ResolveSTSEndpoint(stsEndpoint, secretsRegion string) (endpoints.ResolvedEndpoint, error) {
105+
ep := stsEndpoint
106+
if ep == "" {
107+
r := secretsRegion
108+
if r == "" {
109+
r = DefaultRegionSTS
110+
}
111+
ep = fmt.Sprintf("https://sts.%s.amazonaws.com/", r)
112+
}
113+
114+
signingRegion := DeriveSTSRegionFromEndpoint(ep)
115+
116+
return endpoints.ResolvedEndpoint{
117+
URL: ep,
118+
SigningRegion: signingRegion,
119+
}, nil
120+
}
121+
90122
func secretsManagerGetSecretValue(sess *session.Session, creds *aws.Config, secret string) (string, error) {
91123
svc := secretsmanager.New(sess, creds)
92124
input := &secretsmanager.GetSecretValueInput{
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package provider_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/mongodb/terraform-provider-mongodbatlas/internal/provider"
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func Test_deriveSTSRegionFromEndpoint(t *testing.T) {
12+
testCases := map[string]struct {
13+
input string
14+
expected string
15+
}{
16+
"empty endpoint": {
17+
input: "",
18+
expected: "",
19+
},
20+
"global endpoint": {
21+
input: "https://sts.amazonaws.com",
22+
expected: provider.DefaultRegionSTS,
23+
},
24+
"regional": {
25+
input: "https://sts.us-east-1.amazonaws.com/",
26+
expected: "us-east-1",
27+
},
28+
"regional eu-north-1": {
29+
input: "https://sts.eu-north-1.amazonaws.com/",
30+
expected: "eu-north-1",
31+
},
32+
"malformed url": {
33+
input: "://not-a-url",
34+
expected: provider.DefaultRegionSTS,
35+
},
36+
"unexpected host shape": {
37+
input: "https://sts.something-weird",
38+
expected: provider.DefaultRegionSTS,
39+
},
40+
}
41+
42+
for testName, tc := range testCases {
43+
t.Run(testName, func(t *testing.T) {
44+
t.Parallel()
45+
got := provider.DeriveSTSRegionFromEndpoint(tc.input)
46+
if got != tc.expected {
47+
t.Fatalf("deriveSTSRegionFromEndpoint(%q) = %q; want %q", tc.input, got, tc.expected)
48+
}
49+
})
50+
}
51+
}
52+
53+
func Test_resolveSTSEndpoint(t *testing.T) {
54+
testCases := map[string]struct {
55+
stsEndpoint string
56+
secretsRegion string
57+
expectedURL string
58+
expectedSign string
59+
}{
60+
"explicit regional endpoint": {
61+
stsEndpoint: "https://sts.eu-north-1.amazonaws.com/",
62+
secretsRegion: "us-east-1",
63+
expectedURL: "https://sts.eu-north-1.amazonaws.com/",
64+
expectedSign: "eu-north-1",
65+
},
66+
"global endpoint - us-east-1 signing": {
67+
stsEndpoint: "https://sts.amazonaws.com",
68+
secretsRegion: "eu-west-1",
69+
expectedURL: "https://sts.amazonaws.com",
70+
expectedSign: provider.DefaultRegionSTS,
71+
},
72+
"no endpoint - uses secrets region": {
73+
stsEndpoint: "",
74+
secretsRegion: "us-west-2",
75+
expectedURL: "https://sts.us-west-2.amazonaws.com/",
76+
expectedSign: "us-west-2",
77+
},
78+
"no endpoint and empty region": {
79+
stsEndpoint: "",
80+
secretsRegion: "",
81+
expectedURL: "https://sts.us-east-1.amazonaws.com/",
82+
expectedSign: provider.DefaultRegionSTS,
83+
},
84+
}
85+
86+
for testName, tc := range testCases {
87+
t.Run(testName, func(t *testing.T) {
88+
ep, err := provider.ResolveSTSEndpoint(tc.stsEndpoint, tc.secretsRegion)
89+
require.NoError(t, err)
90+
assert.Equal(t, tc.expectedURL, ep.URL)
91+
assert.Equal(t, tc.expectedSign, ep.SigningRegion)
92+
})
93+
}
94+
}

0 commit comments

Comments
 (0)