@@ -2,6 +2,7 @@ package config
22
33import (
44 "context"
5+ "errors"
56 "fmt"
67 "io/ioutil"
78 "net/http"
@@ -15,9 +16,11 @@ import (
1516 "time"
1617
1718 "github.com/aws/aws-sdk-go-v2/aws"
19+ "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
1820 "github.com/aws/aws-sdk-go-v2/internal/awstesting"
1921 "github.com/aws/aws-sdk-go-v2/service/sso"
2022 "github.com/aws/aws-sdk-go-v2/service/sts"
23+ "github.com/aws/smithy-go"
2124 "github.com/aws/smithy-go/middleware"
2225 smithytime "github.com/aws/smithy-go/time"
2326)
@@ -471,3 +474,122 @@ func TestResolveCredentialsCacheOptions(t *testing.T) {
471474 t .Errorf ("expect options to be called" )
472475 }
473476}
477+
478+ func TestResolveCredentialsIMDSClient (t * testing.T ) {
479+ expectEnabled := func (t * testing.T , err error ) {
480+ if err == nil {
481+ t .Fatalf ("expect error got none" )
482+ }
483+ if e , a := "expected HTTP client error" , err .Error (); ! strings .Contains (a , e ) {
484+ t .Fatalf ("expected %v error in %v" , e , a )
485+ }
486+ }
487+
488+ expectDisabled := func (t * testing.T , err error ) {
489+ var oe * smithy.OperationError
490+ if ! errors .As (err , & oe ) {
491+ t .Fatalf ("unexpected error: %v" , err )
492+ } else {
493+ e := errors .Unwrap (oe )
494+ if e == nil {
495+ t .Fatalf ("unexpected empty operation error: %v" , oe )
496+ } else {
497+ if ! strings .HasPrefix (e .Error (), "access disabled to EC2 IMDS" ) {
498+ t .Fatalf ("unexpected operation error: %v" , oe )
499+ }
500+ }
501+ }
502+ }
503+
504+ testcases := map [string ]struct {
505+ enabledState imds.ClientEnableState
506+ envvar string
507+ expectedState imds.ClientEnableState
508+ expectedError func (* testing.T , error )
509+ }{
510+ "default no options" : {
511+ expectedState : imds .ClientDefaultEnableState ,
512+ expectedError : expectEnabled ,
513+ },
514+
515+ "state enabled" : {
516+ enabledState : imds .ClientEnabled ,
517+ expectedState : imds .ClientEnabled ,
518+ expectedError : expectEnabled ,
519+ },
520+ "state disabled" : {
521+ enabledState : imds .ClientDisabled ,
522+ expectedState : imds .ClientDisabled ,
523+ expectedError : expectDisabled ,
524+ },
525+
526+ "env var DISABLED true" : {
527+ envvar : "true" ,
528+ expectedState : imds .ClientDisabled ,
529+ expectedError : expectDisabled ,
530+ },
531+ "env var DISABLED false" : {
532+ envvar : "false" ,
533+ expectedState : imds .ClientEnabled ,
534+ expectedError : expectEnabled ,
535+ },
536+
537+ "option state enabled overrides env var DISABLED true" : {
538+ enabledState : imds .ClientEnabled ,
539+ envvar : "true" ,
540+ expectedState : imds .ClientEnabled ,
541+ expectedError : expectEnabled ,
542+ },
543+ "option state disabled overrides env var DISABLED false" : {
544+ enabledState : imds .ClientDisabled ,
545+ envvar : "false" ,
546+ expectedState : imds .ClientDisabled ,
547+ expectedError : expectDisabled ,
548+ },
549+ }
550+
551+ for name , tc := range testcases {
552+ t .Run (name , func (t * testing.T ) {
553+ restoreEnv := awstesting .StashEnv ()
554+ defer awstesting .PopEnv (restoreEnv )
555+
556+ var httpClient HTTPClient
557+ if tc .expectedState == imds .ClientDisabled {
558+ httpClient = stubErrorClient {err : fmt .Errorf ("expect HTTP client not to be called" )}
559+ } else {
560+ httpClient = stubErrorClient {err : fmt .Errorf ("expected HTTP client error" )}
561+ }
562+
563+ opts := []func (* LoadOptions ) error {
564+ WithRetryer (func () aws.Retryer { return aws.NopRetryer {} }),
565+ WithHTTPClient (httpClient ),
566+ }
567+
568+ if tc .enabledState != imds .ClientDefaultEnableState {
569+ opts = append (opts ,
570+ WithEC2IMDSClientEnableState (tc .enabledState ),
571+ )
572+ }
573+
574+ if tc .envvar != "" {
575+ os .Setenv ("AWS_EC2_METADATA_DISABLED" , tc .envvar )
576+ }
577+
578+ c , err := LoadDefaultConfig (context .TODO (), opts ... )
579+ if err != nil {
580+ t .Fatalf ("could not load config: %s" , err )
581+ }
582+
583+ creds := c .Credentials
584+
585+ _ , err = creds .Retrieve (context .TODO ())
586+ tc .expectedError (t , err )
587+ })
588+ }
589+ }
590+
591+ type stubErrorClient struct {
592+ err error
593+ }
594+
595+ func (c stubErrorClient ) Do (* http.Request ) (* http.Response , error ) { return nil , c .err }
0 commit comments