@@ -249,7 +249,7 @@ async def call(self):
249
249
class AioCredentials (Credentials ):
250
250
async def get_frozen_credentials (self ):
251
251
return ReadOnlyCredentials (
252
- self .access_key , self .secret_key , self .token
252
+ self .access_key , self .secret_key , self .token , self . account_id
253
253
)
254
254
255
255
@@ -299,6 +299,19 @@ def token(self):
299
299
def token (self , value ):
300
300
self ._token = value
301
301
302
+ @property
303
+ def account_id (self ):
304
+ # TODO: this needs to be resolved
305
+ raise NotImplementedError (
306
+ "missing call to self._refresh. "
307
+ "Use get_frozen_credentials instead"
308
+ )
309
+ return self ._account_id
310
+
311
+ @account_id .setter
312
+ def account_id (self , value ):
313
+ self ._account_id = value
314
+
302
315
async def _refresh (self ):
303
316
if not self .refresh_needed (self ._advisory_refresh_timeout ):
304
317
return
@@ -347,7 +360,7 @@ async def _protected_refresh(self, is_mandatory):
347
360
return
348
361
self ._set_from_data (metadata )
349
362
self ._frozen_credentials = ReadOnlyCredentials (
350
- self ._access_key , self ._secret_key , self ._token
363
+ self ._access_key , self ._secret_key , self ._token , self . _account_id
351
364
)
352
365
if self ._is_expired ():
353
366
msg = (
@@ -370,6 +383,7 @@ def __init__(self, refresh_using, method, time_fetcher=_local_now):
370
383
self ._access_key = None
371
384
self ._secret_key = None
372
385
self ._token = None
386
+ self ._account_id = None
373
387
self ._expiry_time = None
374
388
self ._time_fetcher = time_fetcher
375
389
self ._refresh_lock = asyncio .Lock ()
@@ -399,13 +413,16 @@ async def _get_cached_credentials(self):
399
413
400
414
creds = response ['Credentials' ]
401
415
expiration = _serialize_if_needed (creds ['Expiration' ], iso = True )
402
- return {
416
+ credentials = {
403
417
'access_key' : creds ['AccessKeyId' ],
404
418
'secret_key' : creds ['SecretAccessKey' ],
405
419
'token' : creds ['SessionToken' ],
406
420
'expiry_time' : expiration ,
421
+ 'account_id' : creds .get ('AccountId' ),
407
422
}
408
423
424
+ return credentials
425
+
409
426
410
427
class AioBaseAssumeRoleCredentialFetcher (
411
428
BaseAssumeRoleCredentialFetcher , AioCachedCredentialFetcher
@@ -421,7 +438,9 @@ async def _get_credentials(self):
421
438
kwargs = self ._assume_role_kwargs ()
422
439
client = await self ._create_client ()
423
440
async with client as sts :
424
- return await sts .assume_role (** kwargs )
441
+ response = await sts .assume_role (** kwargs )
442
+ self ._add_account_id_to_response (response )
443
+ return response
425
444
426
445
async def _create_client (self ):
427
446
"""Create an STS client using the source credentials."""
@@ -465,7 +484,9 @@ async def _get_credentials(self):
465
484
# the token, explicitly configure the client to not sign requests.
466
485
config = AioConfig (signature_version = UNSIGNED )
467
486
async with self ._client_creator ('sts' , config = config ) as client :
468
- return await client .assume_role_with_web_identity (** kwargs )
487
+ response = await client .assume_role_with_web_identity (** kwargs )
488
+ self ._add_account_id_to_response (response )
489
+ return response
469
490
470
491
def _assume_role_kwargs (self ):
471
492
"""Get the arguments for assume role based on current configuration."""
@@ -498,6 +519,7 @@ async def load(self):
498
519
secret_key = creds_dict ['secret_key' ],
499
520
token = creds_dict .get ('token' ),
500
521
method = self .METHOD ,
522
+ account_id = creds_dict .get ('account_id' ),
501
523
)
502
524
503
525
async def _retrieve_credentials_using (self , credential_process ):
@@ -528,6 +550,7 @@ async def _retrieve_credentials_using(self, credential_process):
528
550
'secret_key' : parsed ['SecretAccessKey' ],
529
551
'token' : parsed .get ('SessionToken' ),
530
552
'expiry_time' : parsed .get ('Expiration' ),
553
+ 'account_id' : self ._get_account_id (parsed ),
531
554
}
532
555
except KeyError as e :
533
556
raise CredentialRetrievalError (
@@ -573,13 +596,15 @@ async def load(self):
573
596
expiry_time ,
574
597
refresh_using = fetcher ,
575
598
method = self .METHOD ,
599
+ account_id = credentials ['account_id' ],
576
600
)
577
601
578
602
return AioCredentials (
579
603
credentials ['access_key' ],
580
604
credentials ['secret_key' ],
581
605
credentials ['token' ],
582
606
method = self .METHOD ,
607
+ account_id = credentials ['account_id' ],
583
608
)
584
609
else :
585
610
return None
@@ -621,8 +646,13 @@ async def load(self):
621
646
config , self .ACCESS_KEY , self .SECRET_KEY
622
647
)
623
648
token = self ._get_session_token (config )
649
+ account_id = self ._get_account_id (config )
624
650
return AioCredentials (
625
- access_key , secret_key , token , method = self .METHOD
651
+ access_key ,
652
+ secret_key ,
653
+ token ,
654
+ method = self .METHOD ,
655
+ account_id = account_id ,
626
656
)
627
657
628
658
@@ -643,8 +673,13 @@ async def load(self):
643
673
profile_config , self .ACCESS_KEY , self .SECRET_KEY
644
674
)
645
675
token = self ._get_session_token (profile_config )
676
+ account_id = self ._get_account_id (profile_config )
646
677
return AioCredentials (
647
- access_key , secret_key , token , method = self .METHOD
678
+ access_key ,
679
+ secret_key ,
680
+ token ,
681
+ method = self .METHOD ,
682
+ account_id = account_id ,
648
683
)
649
684
else :
650
685
return None
@@ -748,8 +783,8 @@ async def _resolve_credentials_from_profile(self, profile_name):
748
783
):
749
784
# This is only here for backwards compatibility. If this provider
750
785
# isn't given a profile provider builder we still want to be able
751
- # handle the basic static credential case as we would before the
752
- # provile provider builder parameter was added.
786
+ # to handle the basic static credential case as we would before the
787
+ # profile provider builder parameter was added.
753
788
return self ._resolve_static_credentials_from_profile (profile )
754
789
elif self ._has_static_credentials (
755
790
profile
@@ -920,6 +955,7 @@ async def _retrieve_or_fail(self):
920
955
method = self .METHOD ,
921
956
expiry_time = _parse_if_needed (creds ['expiry_time' ]),
922
957
refresh_using = fetcher ,
958
+ account_id = creds .get ('account_id' ),
923
959
)
924
960
925
961
def _create_fetcher (self , full_uri , * args , ** kwargs ):
@@ -941,6 +977,7 @@ async def fetch_creds():
941
977
'secret_key' : response ['SecretAccessKey' ],
942
978
'token' : response ['Token' ],
943
979
'expiry_time' : response ['Expiration' ],
980
+ 'account_id' : response .get ('AccountId' ),
944
981
}
945
982
946
983
return fetch_creds
@@ -1004,6 +1041,7 @@ async def _get_credentials(self):
1004
1041
'Expiration' : self ._parse_timestamp (
1005
1042
credentials ['expiration' ]
1006
1043
),
1044
+ 'AccountId' : self ._account_id ,
1007
1045
},
1008
1046
}
1009
1047
return credentials
0 commit comments