4
4
"encoding/json"
5
5
"fmt"
6
6
"log"
7
+ "net/url"
8
+ "strings"
7
9
8
10
"github.com/aws/aws-sdk-go/aws"
9
11
"github.com/aws/aws-sdk-go/aws/awserr"
@@ -12,48 +14,39 @@ import (
12
14
"github.com/aws/aws-sdk-go/aws/endpoints"
13
15
"github.com/aws/aws-sdk-go/aws/session"
14
16
"github.com/aws/aws-sdk-go/service/secretsmanager"
17
+ "github.com/aws/aws-sdk-go/service/sts"
18
+
15
19
"github.com/mongodb/terraform-provider-mongodbatlas/internal/config"
16
20
)
17
21
18
22
const (
19
- endPointSTSDefault = "https://sts.amazonaws.com"
23
+ endPointSTSHostnameDefault = "sts.amazonaws.com"
24
+ DefaultRegionSTS = "us-east-1"
25
+ minSegmentsForSTSRegionalHost = 4
20
26
)
21
27
22
28
func configureCredentialsSTS (cfg * config.Config , secret , region , awsAccessKeyID , awsSecretAccessKey , awsSessionToken , endpoint string ) (config.Config , error ) {
23
- ep , err := endpoints .GetSTSRegionalEndpoint ("regional" )
24
- if err != nil {
25
- log .Printf ("GetSTSRegionalEndpoint error: %s" , err )
26
- return * cfg , err
27
- }
28
-
29
29
defaultResolver := endpoints .DefaultResolver ()
30
- stsCustResolverFn := func (service , region string , optFns ... func (* endpoints.Options )) (endpoints.ResolvedEndpoint , error ) {
31
- if service == endpoints .StsServiceID {
32
- if endpoint == "" {
33
- return endpoints.ResolvedEndpoint {
34
- URL : endPointSTSDefault ,
35
- SigningRegion : region ,
36
- }, nil
30
+ stsCustResolverFn := func (service , _ string , optFns ... func (* endpoints.Options )) (endpoints.ResolvedEndpoint , error ) {
31
+ if service == sts .EndpointsID {
32
+ resolved , err := ResolveSTSEndpoint (endpoint , region )
33
+ if err != nil {
34
+ return endpoints.ResolvedEndpoint {}, err
37
35
}
38
- return endpoints.ResolvedEndpoint {
39
- URL : endpoint ,
40
- SigningRegion : region ,
41
- }, nil
36
+ return resolved , nil
42
37
}
43
-
44
38
return defaultResolver .EndpointFor (service , region , optFns ... )
45
39
}
46
40
47
41
sess := session .Must (session .NewSession (& aws.Config {
48
- Region : aws .String (region ),
49
- Credentials : credentials .NewStaticCredentials (awsAccessKeyID , awsSecretAccessKey , awsSessionToken ),
50
- STSRegionalEndpoint : ep ,
51
- EndpointResolver : endpoints .ResolverFunc (stsCustResolverFn ),
42
+ Region : aws .String (region ),
43
+ Credentials : credentials .NewStaticCredentials (awsAccessKeyID , awsSecretAccessKey , awsSessionToken ),
44
+ EndpointResolver : endpoints .ResolverFunc (stsCustResolverFn ),
52
45
}))
53
46
54
47
creds := stscreds .NewCredentials (sess , cfg .AssumeRole .RoleARN )
55
48
56
- _ , err = sess .Config .Credentials .Get ()
49
+ _ , err : = sess .Config .Credentials .Get ()
57
50
if err != nil {
58
51
log .Printf ("Session get credentials error: %s" , err )
59
52
return * cfg , err
@@ -87,6 +80,45 @@ func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID,
87
80
return * cfg , nil
88
81
}
89
82
83
+ func DeriveSTSRegionFromEndpoint (ep string ) string {
84
+ if ep == "" {
85
+ return ""
86
+ }
87
+ u , err := url .Parse (ep )
88
+ if err != nil {
89
+ return DefaultRegionSTS
90
+ }
91
+ host := u .Hostname () // valid values: sts.us-west-2.amazonaws.com or sts.amazonaws.com
92
+
93
+ if host == endPointSTSHostnameDefault {
94
+ return DefaultRegionSTS
95
+ }
96
+
97
+ parts := strings .Split (host , "." )
98
+ if len (parts ) >= minSegmentsForSTSRegionalHost && parts [0 ] == "sts" {
99
+ return parts [1 ]
100
+ }
101
+ return DefaultRegionSTS
102
+ }
103
+
104
+ func ResolveSTSEndpoint (stsEndpoint , secretsRegion string ) (endpoints.ResolvedEndpoint , error ) {
105
+ ep := stsEndpoint
106
+ if ep == "" {
107
+ r := secretsRegion
108
+ if r == "" {
109
+ r = DefaultRegionSTS
110
+ }
111
+ ep = fmt .Sprintf ("https://sts.%s.amazonaws.com/" , r )
112
+ }
113
+
114
+ signingRegion := DeriveSTSRegionFromEndpoint (ep )
115
+
116
+ return endpoints.ResolvedEndpoint {
117
+ URL : ep ,
118
+ SigningRegion : signingRegion ,
119
+ }, nil
120
+ }
121
+
90
122
func secretsManagerGetSecretValue (sess * session.Session , creds * aws.Config , secret string ) (string , error ) {
91
123
svc := secretsmanager .New (sess , creds )
92
124
input := & secretsmanager.GetSecretValueInput {
0 commit comments