Skip to content

Commit 838b57a

Browse files
authored
Fix assuming a role with an endpoint set (#108)
1 parent c5f0a4a commit 838b57a

File tree

2 files changed

+73
-14
lines changed

2 files changed

+73
-14
lines changed

pkg/awsds/sessions.go

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,6 @@ func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
255255
panic(fmt.Sprintf("Unrecognized authType: %d", c.Settings.AuthType))
256256
}
257257

258-
if c.Settings.Endpoint != "" {
259-
cfgs = append(cfgs, &aws.Config{Endpoint: aws.String(c.Settings.Endpoint)})
260-
}
261-
262-
sess, err := newSession(cfgs...)
263-
if err != nil {
264-
return nil, err
265-
}
266-
267258
duration := stscreds.DefaultDuration
268259
if sc.authSettings.SessionDuration != nil {
269260
duration = *sc.authSettings.SessionDuration
@@ -273,7 +264,16 @@ func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
273264
// We should assume a role in AWS
274265
backend.Logger.Debug("Trying to assume role in AWS", "arn", c.Settings.AssumeRoleARN)
275266

276-
cfgs := []*aws.Config{
267+
if c.Settings.Endpoint != "" {
268+
cfgs = append(cfgs, &aws.Config{Endpoint: aws.String(getSTSEndpoint(c.Settings.Endpoint))})
269+
}
270+
271+
sess, err := newSession(cfgs...)
272+
if err != nil {
273+
return nil, err
274+
}
275+
276+
cfgs = []*aws.Config{
277277
{
278278
CredentialsChainVerboseErrors: aws.Bool(true),
279279
},
@@ -296,10 +296,15 @@ func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
296296
regionCfg = &aws.Config{Region: aws.String(c.Settings.Region)}
297297
cfgs = append(cfgs, regionCfg)
298298
}
299-
sess, err = newSession(cfgs...)
300-
if err != nil {
301-
return nil, err
302-
}
299+
}
300+
301+
if c.Settings.Endpoint != "" {
302+
cfgs = append(cfgs, &aws.Config{Endpoint: aws.String(c.Settings.Endpoint)})
303+
}
304+
305+
sess, err := newSession(cfgs...)
306+
if err != nil {
307+
return nil, err
303308
}
304309

305310
if c.UserAgentName != nil {
@@ -319,3 +324,30 @@ func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
319324

320325
return sess, nil
321326
}
327+
328+
// getSTSEndpoint checks if the set endpoint is a fips endpoint, and if so, returns the STS fips endpoint for the same region
329+
func getSTSEndpoint(endpoint string) string {
330+
if endpoint == "" {
331+
return ""
332+
}
333+
if strings.Contains(endpoint, "fips") {
334+
switch {
335+
case strings.Contains(endpoint, "us-east-1"):
336+
return "sts-fips.us-east-1.amazonaws.com"
337+
case strings.Contains(endpoint, "us-east-2"):
338+
return "sts-fips.us-east-2.amazonaws.com"
339+
case strings.Contains(endpoint, "us-west-1"):
340+
return "sts-fips.us-west-1.amazonaws.com"
341+
case strings.Contains(endpoint, "us-west-2"):
342+
return "sts-fips.us-west-2.amazonaws.com"
343+
}
344+
}
345+
346+
if strings.Contains(endpoint, "us-gov-east-1") {
347+
return "sts.us-gov-east-1.amazonaws.com"
348+
}
349+
if strings.Contains(endpoint, "us-gov-west-1") {
350+
return "sts.us-gov-west-1.amazonaws.com"
351+
}
352+
return ""
353+
}

pkg/awsds/sessions_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,33 @@ func TestNewSession_AssumeRole(t *testing.T) {
211211
require.NotNil(t, sess)
212212
assert.Equal(t, "us-gov-east-1", *sess.Config.Region)
213213
})
214+
215+
t.Run("Assume role is enabled with a fips endpoint", func(t *testing.T) {
216+
defer unsetEnvironmentVariables()
217+
fakeNewSTSCredentials := newSTSCredentials
218+
newSTSCredentials = func(c client.ConfigProvider, roleARN string,
219+
options ...func(*stscreds.AssumeRoleProvider)) *credentials.Credentials {
220+
sess := c.(*session.Session)
221+
// Verify that we are using the correct sts endpoint
222+
assert.Equal(t, "sts-fips.us-east-1.amazonaws.com", *sess.Config.Endpoint)
223+
return fakeNewSTSCredentials(c, roleARN, options...)
224+
}
225+
settings := AWSDatasourceSettings{
226+
AssumeRoleARN: "test",
227+
Region: "us-east-1",
228+
Endpoint: "athena-fips.us-east-1.amazonaws.com",
229+
}
230+
require.NoError(t, os.Setenv(AllowedAuthProvidersEnvVarKeyName, "default"))
231+
require.NoError(t, os.Setenv(AssumeRoleEnabledEnvVarKeyName, "true"))
232+
cache := NewSessionCache()
233+
sess, err := cache.GetSession(SessionConfig{Settings: settings})
234+
newSTSCredentials = fakeNewSTSCredentials
235+
236+
require.NoError(t, err)
237+
require.NotNil(t, sess)
238+
// Verify that we use the endpoint from the settings
239+
assert.Equal(t, settings.Endpoint, *sess.Config.Endpoint)
240+
})
214241
}
215242

216243
func TestNewSession_AllowedAuthProviders(t *testing.T) {

0 commit comments

Comments
 (0)