Skip to content

Commit 0122506

Browse files
vanclueverldez
andauthored
route53: Allow static credentials to be supplied (#1746)
Co-authored-by: Fernandez Ludovic <[email protected]>
1 parent 07d957f commit 0122506

File tree

2 files changed

+169
-2
lines changed

2 files changed

+169
-2
lines changed

providers/dns/route53/route53.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/aws/aws-sdk-go/aws"
1212
"github.com/aws/aws-sdk-go/aws/client"
13+
"github.com/aws/aws-sdk-go/aws/credentials"
1314
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
1415
"github.com/aws/aws-sdk-go/aws/request"
1516
"github.com/aws/aws-sdk-go/aws/session"
@@ -28,15 +29,22 @@ const (
2829
EnvRegion = envNamespace + "REGION"
2930
EnvHostedZoneID = envNamespace + "HOSTED_ZONE_ID"
3031
EnvMaxRetries = envNamespace + "MAX_RETRIES"
32+
EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN"
3133

3234
EnvTTL = envNamespace + "TTL"
3335
EnvPropagationTimeout = envNamespace + "PROPAGATION_TIMEOUT"
3436
EnvPollingInterval = envNamespace + "POLLING_INTERVAL"
35-
EnvAssumeRoleArn = envNamespace + "ASSUME_ROLE_ARN"
3637
)
3738

3839
// Config is used to configure the creation of the DNSProvider.
3940
type Config struct {
41+
// Static credential chain.
42+
// These are not set via environment for the time being and are only used if they are explicitly provided.
43+
AccessKeyID string
44+
SecretAccessKey string
45+
SessionToken string
46+
Region string
47+
4048
HostedZoneID string
4149
MaxRetries int
4250
AssumeRoleArn string
@@ -301,10 +309,23 @@ func (d *DNSProvider) getHostedZoneID(fqdn string) (string, error) {
301309
}
302310

303311
func createSession(config *Config) (*session.Session, error) {
312+
if err := createSessionCheckParams(config); err != nil {
313+
return nil, err
314+
}
315+
304316
retry := customRetryer{}
305317
retry.NumMaxRetries = config.MaxRetries
306318

307-
sessionCfg := request.WithRetryer(aws.NewConfig(), retry)
319+
awsConfig := aws.NewConfig()
320+
if config.AccessKeyID != "" && config.SecretAccessKey != "" {
321+
awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(config.AccessKeyID, config.SecretAccessKey, config.SessionToken))
322+
}
323+
324+
if config.Region != "" {
325+
awsConfig = awsConfig.WithRegion(config.Region)
326+
}
327+
328+
sessionCfg := request.WithRetryer(awsConfig, retry)
308329

309330
sess, err := session.NewSessionWithOptions(session.Options{Config: *sessionCfg})
310331
if err != nil {
@@ -320,3 +341,19 @@ func createSession(config *Config) (*session.Session, error) {
320341
Credentials: stscreds.NewCredentials(sess, config.AssumeRoleArn),
321342
})
322343
}
344+
345+
func createSessionCheckParams(config *Config) error {
346+
if config == nil {
347+
return errors.New("config is nil")
348+
}
349+
350+
switch {
351+
case config.SessionToken != "" && config.AccessKeyID == "" && config.SecretAccessKey == "":
352+
return errors.New("SessionToken must be supplied with AccessKeyID and SecretAccessKey")
353+
354+
case config.AccessKeyID == "" && config.SecretAccessKey != "" || config.AccessKeyID != "" && config.SecretAccessKey == "":
355+
return errors.New("AccessKeyID and SecretAccessKey must be supplied together")
356+
}
357+
358+
return nil
359+
}

providers/dns/route53/route53_test.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,133 @@ func TestDNSProvider_Present(t *testing.T) {
177177
err := provider.Present(domain, "", keyAuth)
178178
require.NoError(t, err, "Expected Present to return no error")
179179
}
180+
181+
func TestCreateSession(t *testing.T) {
182+
testCases := []struct {
183+
desc string
184+
env map[string]string
185+
config *Config
186+
wantCreds credentials.Value
187+
wantDefaultChain bool
188+
wantRegion string
189+
wantErr string
190+
}{
191+
{
192+
desc: "config is nil",
193+
wantErr: "config is nil",
194+
},
195+
{
196+
desc: "session token without access key id or secret access key",
197+
config: &Config{SessionToken: "foo"},
198+
wantErr: "SessionToken must be supplied with AccessKeyID and SecretAccessKey",
199+
},
200+
{
201+
desc: "access key id without secret access key",
202+
config: &Config{AccessKeyID: "foo"},
203+
wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
204+
},
205+
{
206+
desc: "access key id without secret access key",
207+
config: &Config{SecretAccessKey: "foo"},
208+
wantErr: "AccessKeyID and SecretAccessKey must be supplied together",
209+
},
210+
{
211+
desc: "credentials from default chain",
212+
config: &Config{},
213+
wantDefaultChain: true,
214+
},
215+
{
216+
desc: "static credentials",
217+
config: &Config{
218+
AccessKeyID: "one",
219+
SecretAccessKey: "two",
220+
},
221+
wantCreds: credentials.Value{
222+
AccessKeyID: "one",
223+
SecretAccessKey: "two",
224+
SessionToken: "",
225+
ProviderName: credentials.StaticProviderName,
226+
},
227+
},
228+
{
229+
desc: "static credentials with session token",
230+
config: &Config{
231+
AccessKeyID: "one",
232+
SecretAccessKey: "two",
233+
SessionToken: "three",
234+
},
235+
wantCreds: credentials.Value{
236+
AccessKeyID: "one",
237+
SecretAccessKey: "two",
238+
SessionToken: "three",
239+
ProviderName: credentials.StaticProviderName,
240+
},
241+
},
242+
{
243+
desc: "region from env",
244+
config: &Config{},
245+
env: map[string]string{
246+
"AWS_REGION": "foo",
247+
},
248+
wantDefaultChain: true,
249+
wantRegion: "foo",
250+
},
251+
{
252+
desc: "static region",
253+
config: &Config{
254+
Region: "one",
255+
},
256+
env: map[string]string{
257+
"AWS_REGION": "foo",
258+
},
259+
wantDefaultChain: true,
260+
wantRegion: "one",
261+
},
262+
}
263+
264+
for _, test := range testCases {
265+
t.Run(test.desc, func(t *testing.T) {
266+
defer envTest.RestoreEnv()
267+
envTest.ClearEnv()
268+
269+
envTest.Apply(test.env)
270+
271+
sess, err := createSession(test.config)
272+
requireErr(t, err, test.wantErr)
273+
274+
if err != nil {
275+
return
276+
}
277+
278+
gotCreds, err := sess.Config.Credentials.Get()
279+
280+
if test.wantDefaultChain {
281+
assert.NotEqual(t, credentials.StaticProviderName, gotCreds.ProviderName)
282+
} else {
283+
require.NoError(t, err)
284+
assert.Equal(t, test.wantCreds, gotCreds)
285+
}
286+
287+
if test.wantRegion != "" {
288+
assert.Equal(t, test.wantRegion, aws.StringValue(sess.Config.Region))
289+
}
290+
})
291+
}
292+
}
293+
294+
func requireErr(t *testing.T, err error, wantErr string) {
295+
t.Helper()
296+
297+
switch {
298+
case err != nil && wantErr == "":
299+
// force the assertion error.
300+
require.NoError(t, err)
301+
302+
case err == nil && wantErr != "":
303+
// force the assertion error.
304+
require.EqualError(t, err, wantErr)
305+
306+
case err != nil && wantErr != "":
307+
require.EqualError(t, err, wantErr)
308+
}
309+
}

0 commit comments

Comments
 (0)