@@ -19,17 +19,14 @@ class DataPlaneTokenSource:
1919 """
2020
2121 # TODO: Enable async once its stable. @oauth_credentials_provider must also have async enabled.
22- def __init__ (self ,
23- token_exchange_host : str ,
24- cpts : Callable [[], Token ],
25- disable_async : Optional [bool ] = True ):
22+ def __init__ (self , token_exchange_host : str , cpts : Callable [[], Token ], disable_async : Optional [bool ] = True ):
2623 self ._cpts = cpts
2724 self ._token_exchange_host = token_exchange_host
2825 self ._token_sources = {}
2926 self ._disable_async = disable_async
3027 self ._lock = threading .Lock ()
3128
32- def get_token (self , endpoint , auth_details ):
29+ def token (self , endpoint , auth_details ):
3330 key = f"{ endpoint } :{ auth_details } "
3431
3532 # First, try to read without acquiring the lock to avoid contention.
@@ -43,8 +40,9 @@ def get_token(self, endpoint, auth_details):
4340 # Another thread might have created it while we were waiting for the lock.
4441 token_source = self ._token_sources .get (key )
4542 if not token_source :
46- token_source = DataPlaneEndpointTokenSource (self ._token_exchange_host , self ._cpts ,
47- auth_details , self ._disable_async )
43+ token_source = DataPlaneEndpointTokenSource (
44+ self ._token_exchange_host , self ._cpts , auth_details , self ._disable_async
45+ )
4846 self ._token_sources [key ] = token_source
4947
5048 return token_source .token ()
@@ -55,8 +53,7 @@ class DataPlaneEndpointTokenSource(oauth.Refreshable):
5553 EXPERIMENTAL A token source for a specific DataPlane endpoint.
5654 """
5755
58- def __init__ (self , token_exchange_host : str , cpts : Callable [[], Token ], auth_details : str ,
59- disable_async : bool ):
56+ def __init__ (self , token_exchange_host : str , cpts : Callable [[], Token ], auth_details : str , disable_async : bool ):
6057 super ().__init__ (disable_async = disable_async )
6158 self ._auth_details = auth_details
6259 self ._cpts = cpts
@@ -65,16 +62,20 @@ def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], auth_det
6562 def refresh (self ) -> Token :
6663 control_plane_token = self ._cpts ()
6764 headers = {"Content-Type" : URL_ENCODED_CONTENT_TYPE }
68- params = parse .urlencode ({
69- "grant_type" : JWT_BEARER_GRANT_TYPE ,
70- "authorization_details" : self ._auth_details ,
71- "assertion" : control_plane_token .access_token
72- })
73- return oauth .retrieve_token (client_id = "" ,
74- client_secret = "" ,
75- token_url = self ._token_exchange_host + OIDC_TOKEN_PATH ,
76- params = params ,
77- headers = headers )
65+ params = parse .urlencode (
66+ {
67+ "grant_type" : JWT_BEARER_GRANT_TYPE ,
68+ "authorization_details" : self ._auth_details ,
69+ "assertion" : control_plane_token .access_token ,
70+ }
71+ )
72+ return oauth .retrieve_token (
73+ client_id = "" ,
74+ client_secret = "" ,
75+ token_url = self ._token_exchange_host + OIDC_TOKEN_PATH ,
76+ params = params ,
77+ headers = headers ,
78+ )
7879
7980
8081@dataclass
0 commit comments