Skip to content

Commit cb4e8d3

Browse files
authored
Create GetSessionWithAuthSettings (#144)
1 parent ab5b2ad commit cb4e8d3

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

pkg/awsds/sessions.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ var newRemoteCredentials = func(sess *session.Session) *credentials.Credentials
7373
return credentials.NewCredentials(defaults.RemoteCredProvider(*sess.Config, sess.Handlers))
7474
}
7575

76+
type GetSessionConfig struct {
77+
Settings AWSDatasourceSettings
78+
HTTPClient *http.Client
79+
UserAgentName *string
80+
}
81+
7682
type SessionConfig struct {
7783
Settings AWSDatasourceSettings
7884
HTTPClient *http.Client
@@ -100,7 +106,7 @@ func isOptInRegion(region string) bool {
100106
return regions[region]
101107
}
102108

103-
// GetSession returns a session from the config and possible region overrides -- implements AmazonSessionProvider
109+
// Deprecated: use GetSessionWithAuthSettings instead
104110
func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
105111
if c.Settings.Region == "" && c.Settings.DefaultRegion != "" {
106112
// DefaultRegion is deprecated, Region should be used instead
@@ -289,6 +295,16 @@ func (sc *SessionCache) GetSession(c SessionConfig) (*session.Session, error) {
289295
return sess, nil
290296
}
291297

298+
// AuthSettings can be grabed from the datasource instance's context with ReadSettingsFromContext
299+
func (sc *SessionCache) GetSessionWithAuthSettings(c GetSessionConfig, as AuthSettings) (*session.Session, error) {
300+
return sc.GetSession(SessionConfig{
301+
Settings: c.Settings,
302+
HTTPClient: c.HTTPClient,
303+
UserAgentName: c.UserAgentName,
304+
AuthSettings: &as,
305+
})
306+
}
307+
292308
// getSTSEndpoint returns true if the set endpoint is a fips endpoint
293309
func isFIPSEndpoint(endpoint string) bool {
294310
return strings.Contains(endpoint, "fips") ||

pkg/awsds/sessions_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,21 @@ func TestWithCustomHTTPClient(t *testing.T) {
592592
require.NotNil(t, sess)
593593
assert.Equal(t, time.Duration(123), sess.Config.HTTPClient.Timeout)
594594
}
595+
596+
func TestGetSessionWithAuthSettings(t *testing.T) {
597+
t.Run("it uses the passed in for auth settings", func(t *testing.T) {
598+
sessionConfig := GetSessionConfig{
599+
Settings: AWSDatasourceSettings{
600+
AuthType: AuthTypeKeys,
601+
AccessKey: "foo",
602+
SecretKey: "bar",
603+
},
604+
}
605+
authSettings := AuthSettings{
606+
AllowedAuthProviders: []string{"ec2_iam_role"},
607+
}
608+
sessionCache := NewSessionCache()
609+
_, err := sessionCache.GetSessionWithAuthSettings(sessionConfig, authSettings)
610+
require.EqualError(t, err, "attempting to use an auth type that is not allowed: \"keys\"")
611+
})
612+
}

0 commit comments

Comments
 (0)