@@ -26,22 +26,28 @@ import (
2626 "golang.org/x/oauth2"
2727)
2828
29- type awsSecurityCredentials struct {
30- AccessKeyID string `json:"AccessKeyID"`
29+ // AwsSecurityCredentials models AWS security credentials.
30+ type AwsSecurityCredentials struct {
31+ // AccessKeyId is the AWS Access Key ID - Required.
32+ AccessKeyID string `json:"AccessKeyID"`
33+ // SecretAccessKey is the AWS Secret Access Key - Required.
3134 SecretAccessKey string `json:"SecretAccessKey"`
32- SecurityToken string `json:"Token"`
35+ // SessionToken is the AWS Session token. This should be provided for temporary AWS security credentials - Optional.
36+ SessionToken string `json:"Token"`
3337}
3438
3539// awsRequestSigner is a utility class to sign http requests using a AWS V4 signature.
3640type awsRequestSigner struct {
3741 RegionName string
38- AwsSecurityCredentials awsSecurityCredentials
42+ AwsSecurityCredentials * AwsSecurityCredentials
3943}
4044
4145// getenv aliases os.Getenv for testing
4246var getenv = os .Getenv
4347
4448const (
49+ defaultRegionalCredentialVerificationUrl = "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
50+
4551 // AWS Signature Version 4 signing algorithm identifier.
4652 awsAlgorithm = "AWS4-HMAC-SHA256"
4753
@@ -197,8 +203,8 @@ func (rs *awsRequestSigner) SignRequest(req *http.Request) error {
197203
198204 signedRequest .Header .Add ("host" , requestHost (req ))
199205
200- if rs .AwsSecurityCredentials .SecurityToken != "" {
201- signedRequest .Header .Add (awsSecurityTokenHeader , rs .AwsSecurityCredentials .SecurityToken )
206+ if rs .AwsSecurityCredentials .SessionToken != "" {
207+ signedRequest .Header .Add (awsSecurityTokenHeader , rs .AwsSecurityCredentials .SessionToken )
202208 }
203209
204210 if signedRequest .Header .Get ("date" ) == "" {
@@ -251,16 +257,18 @@ func (rs *awsRequestSigner) generateAuthentication(req *http.Request, timestamp
251257}
252258
253259type awsCredentialSource struct {
254- EnvironmentID string
255- RegionURL string
256- RegionalCredVerificationURL string
257- CredVerificationURL string
258- IMDSv2SessionTokenURL string
259- TargetResource string
260- requestSigner * awsRequestSigner
261- region string
262- ctx context.Context
263- client * http.Client
260+ environmentID string
261+ regionURL string
262+ regionalCredVerificationURL string
263+ credVerificationURL string
264+ imdsv2SessionTokenURL string
265+ targetResource string
266+ requestSigner * awsRequestSigner
267+ region string
268+ ctx context.Context
269+ client * http.Client
270+ awsSecurityCredentialsSupplier AwsSecurityCredentialsSupplier
271+ supplierOptions SupplierOptions
264272}
265273
266274type awsRequestHeader struct {
@@ -292,18 +300,25 @@ func canRetrieveSecurityCredentialFromEnvironment() bool {
292300 return getenv (awsAccessKeyId ) != "" && getenv (awsSecretAccessKey ) != ""
293301}
294302
295- func shouldUseMetadataServer () bool {
296- return ! canRetrieveRegionFromEnvironment () || ! canRetrieveSecurityCredentialFromEnvironment ()
303+ func ( cs awsCredentialSource ) shouldUseMetadataServer () bool {
304+ return cs . awsSecurityCredentialsSupplier == nil && ( ! canRetrieveRegionFromEnvironment () || ! canRetrieveSecurityCredentialFromEnvironment () )
297305}
298306
299307func (cs awsCredentialSource ) credentialSourceType () string {
308+ if cs .awsSecurityCredentialsSupplier != nil {
309+ return "programmatic"
310+ }
300311 return "aws"
301312}
302313
303314func (cs awsCredentialSource ) subjectToken () (string , error ) {
315+ // Set Defaults
316+ if cs .regionalCredVerificationURL == "" {
317+ cs .regionalCredVerificationURL = defaultRegionalCredentialVerificationUrl
318+ }
304319 if cs .requestSigner == nil {
305320 headers := make (map [string ]string )
306- if shouldUseMetadataServer () {
321+ if cs . shouldUseMetadataServer () {
307322 awsSessionToken , err := cs .getAWSSessionToken ()
308323 if err != nil {
309324 return "" , err
@@ -318,8 +333,8 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
318333 if err != nil {
319334 return "" , err
320335 }
321-
322- if cs . region , err = cs . getRegion ( headers ); err != nil {
336+ cs . region , err = cs . getRegion ( headers )
337+ if err != nil {
323338 return "" , err
324339 }
325340
@@ -331,16 +346,16 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
331346
332347 // Generate the signed request to AWS STS GetCallerIdentity API.
333348 // Use the required regional endpoint. Otherwise, the request will fail.
334- req , err := http .NewRequest ("POST" , strings .Replace (cs .RegionalCredVerificationURL , "{region}" , cs .region , 1 ), nil )
349+ req , err := http .NewRequest ("POST" , strings .Replace (cs .regionalCredVerificationURL , "{region}" , cs .region , 1 ), nil )
335350 if err != nil {
336351 return "" , err
337352 }
338353 // The full, canonical resource name of the workload identity pool
339354 // provider, with or without the HTTPS prefix.
340355 // Including this header as part of the signature is recommended to
341356 // ensure data integrity.
342- if cs .TargetResource != "" {
343- req .Header .Add ("x-goog-cloud-target-resource" , cs .TargetResource )
357+ if cs .targetResource != "" {
358+ req .Header .Add ("x-goog-cloud-target-resource" , cs .targetResource )
344359 }
345360 cs .requestSigner .SignRequest (req )
346361
@@ -387,11 +402,11 @@ func (cs awsCredentialSource) subjectToken() (string, error) {
387402}
388403
389404func (cs * awsCredentialSource ) getAWSSessionToken () (string , error ) {
390- if cs .IMDSv2SessionTokenURL == "" {
405+ if cs .imdsv2SessionTokenURL == "" {
391406 return "" , nil
392407 }
393408
394- req , err := http .NewRequest ("PUT" , cs .IMDSv2SessionTokenURL , nil )
409+ req , err := http .NewRequest ("PUT" , cs .imdsv2SessionTokenURL , nil )
395410 if err != nil {
396411 return "" , err
397412 }
@@ -410,25 +425,29 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
410425 }
411426
412427 if resp .StatusCode != 200 {
413- return "" , fmt .Errorf ("oauth2/google: unable to retrieve AWS session token - %s" , string (respBody ))
428+ return "" , fmt .Errorf ("oauth2/google/externalaccount : unable to retrieve AWS session token - %s" , string (respBody ))
414429 }
415430
416431 return string (respBody ), nil
417432}
418433
419434func (cs * awsCredentialSource ) getRegion (headers map [string ]string ) (string , error ) {
435+ if cs .awsSecurityCredentialsSupplier != nil {
436+ return cs .awsSecurityCredentialsSupplier .AwsRegion (cs .ctx , cs .supplierOptions )
437+ }
420438 if canRetrieveRegionFromEnvironment () {
421439 if envAwsRegion := getenv (awsRegion ); envAwsRegion != "" {
440+ cs .region = envAwsRegion
422441 return envAwsRegion , nil
423442 }
424443 return getenv ("AWS_DEFAULT_REGION" ), nil
425444 }
426445
427- if cs .RegionURL == "" {
428- return "" , errors .New ("oauth2/google: unable to determine AWS region" )
446+ if cs .regionURL == "" {
447+ return "" , errors .New ("oauth2/google/externalaccount : unable to determine AWS region" )
429448 }
430449
431- req , err := http .NewRequest ("GET" , cs .RegionURL , nil )
450+ req , err := http .NewRequest ("GET" , cs .regionURL , nil )
432451 if err != nil {
433452 return "" , err
434453 }
@@ -449,7 +468,7 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
449468 }
450469
451470 if resp .StatusCode != 200 {
452- return "" , fmt .Errorf ("oauth2/google: unable to retrieve AWS region - %s" , string (respBody ))
471+ return "" , fmt .Errorf ("oauth2/google/externalaccount : unable to retrieve AWS region - %s" , string (respBody ))
453472 }
454473
455474 // This endpoint will return the region in format: us-east-2b.
@@ -461,12 +480,15 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
461480 return string (respBody [:respBodyEnd ]), nil
462481}
463482
464- func (cs * awsCredentialSource ) getSecurityCredentials (headers map [string ]string ) (result awsSecurityCredentials , err error ) {
483+ func (cs * awsCredentialSource ) getSecurityCredentials (headers map [string ]string ) (result * AwsSecurityCredentials , err error ) {
484+ if cs .awsSecurityCredentialsSupplier != nil {
485+ return cs .awsSecurityCredentialsSupplier .AwsSecurityCredentials (cs .ctx , cs .supplierOptions )
486+ }
465487 if canRetrieveSecurityCredentialFromEnvironment () {
466- return awsSecurityCredentials {
488+ return & AwsSecurityCredentials {
467489 AccessKeyID : getenv (awsAccessKeyId ),
468490 SecretAccessKey : getenv (awsSecretAccessKey ),
469- SecurityToken : getenv (awsSessionToken ),
491+ SessionToken : getenv (awsSessionToken ),
470492 }, nil
471493 }
472494
@@ -481,20 +503,20 @@ func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string)
481503 }
482504
483505 if credentials .AccessKeyID == "" {
484- return result , errors .New ("oauth2/google: missing AccessKeyId credential" )
506+ return result , errors .New ("oauth2/google/externalaccount : missing AccessKeyId credential" )
485507 }
486508
487509 if credentials .SecretAccessKey == "" {
488- return result , errors .New ("oauth2/google: missing SecretAccessKey credential" )
510+ return result , errors .New ("oauth2/google/externalaccount : missing SecretAccessKey credential" )
489511 }
490512
491- return credentials , nil
513+ return & credentials , nil
492514}
493515
494- func (cs * awsCredentialSource ) getMetadataSecurityCredentials (roleName string , headers map [string ]string ) (awsSecurityCredentials , error ) {
495- var result awsSecurityCredentials
516+ func (cs * awsCredentialSource ) getMetadataSecurityCredentials (roleName string , headers map [string ]string ) (AwsSecurityCredentials , error ) {
517+ var result AwsSecurityCredentials
496518
497- req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/%s" , cs .CredVerificationURL , roleName ), nil )
519+ req , err := http .NewRequest ("GET" , fmt .Sprintf ("%s/%s" , cs .credVerificationURL , roleName ), nil )
498520 if err != nil {
499521 return result , err
500522 }
@@ -516,19 +538,19 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h
516538 }
517539
518540 if resp .StatusCode != 200 {
519- return result , fmt .Errorf ("oauth2/google: unable to retrieve AWS security credentials - %s" , string (respBody ))
541+ return result , fmt .Errorf ("oauth2/google/externalaccount : unable to retrieve AWS security credentials - %s" , string (respBody ))
520542 }
521543
522544 err = json .Unmarshal (respBody , & result )
523545 return result , err
524546}
525547
526548func (cs * awsCredentialSource ) getMetadataRoleName (headers map [string ]string ) (string , error ) {
527- if cs .CredVerificationURL == "" {
528- return "" , errors .New ("oauth2/google: unable to determine the AWS metadata server security credentials endpoint" )
549+ if cs .credVerificationURL == "" {
550+ return "" , errors .New ("oauth2/google/externalaccount : unable to determine the AWS metadata server security credentials endpoint" )
529551 }
530552
531- req , err := http .NewRequest ("GET" , cs .CredVerificationURL , nil )
553+ req , err := http .NewRequest ("GET" , cs .credVerificationURL , nil )
532554 if err != nil {
533555 return "" , err
534556 }
@@ -549,7 +571,7 @@ func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (s
549571 }
550572
551573 if resp .StatusCode != 200 {
552- return "" , fmt .Errorf ("oauth2/google: unable to retrieve AWS role name - %s" , string (respBody ))
574+ return "" , fmt .Errorf ("oauth2/google/externalaccount : unable to retrieve AWS role name - %s" , string (respBody ))
553575 }
554576
555577 return string (respBody ), nil
0 commit comments