44# license information.
55# -------------------------------------------------------------------------
66
7- from typing import Any , MutableMapping , TypeVar , cast
7+ from typing import Any , MutableMapping , TypeVar , cast , Optional
88
99from azure .core .pipeline .policies import AsyncBearerTokenCredentialPolicy
1010from azure .core .pipeline import PipelineRequest
2020class AsyncCosmosBearerTokenCredentialPolicy (AsyncBearerTokenCredentialPolicy ):
2121 AadDefaultScope = "https://cosmos.azure.com/.default"
2222
23- def __init__ (self , credential , account_scope : str , override_scope : str ):
23+ def __init__ (self , credential , account_scope : str , override_scope : Optional [ str ] = None ):
2424 self ._account_scope = account_scope
2525 self ._override_scope = override_scope
2626 self ._current_scope = override_scope or account_scope
@@ -42,25 +42,11 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4242 :type request: ~azure.core.pipeline.PipelineRequest
4343 :raises: :class:`~azure.core.exceptions.ServiceRequestError`
4444 """
45- await self .authorize_request (request )
46- await super ().on_request (request )
47- # The None-check for self._token is done in the parent on_request
48- self ._update_headers (request .http_request .headers , cast (AccessToken , self ._token ).token )
49-
50- async def authorize_request (self , request : PipelineRequest [HTTPRequestType ], * scopes : str , ** kwargs : Any ) -> None :
51- """Acquire a token from the credential and authorize the request with it.
52-
53- Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
54- authorize future requests.
55-
56- :param ~azure.core.pipeline.PipelineRequest request: the request
57- :param str scopes: required scopes of authentication
58- """
5945 tried_fallback = False
6046 while True :
6147 try :
62- await super ().authorize_request (request , self . _current_scope , ** kwargs )
63- # The None-check for self._token is done in the parent authorize_request
48+ await super ().on_request (request )
49+ # The None-check for self._token is done in the parent on_request
6450 self ._update_headers (request .http_request .headers , cast (AccessToken , self ._token ).token )
6551 break
6652 except Exception as ex :
@@ -71,7 +57,22 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
7157 self ._current_scope != self .AadDefaultScope and
7258 "AADSTS500011" in str (ex )
7359 ):
60+ self ._scopes = (self .AadDefaultScope ,)
7461 self ._current_scope = self .AadDefaultScope
7562 tried_fallback = True
7663 continue
7764 raise
65+
66+ async def authorize_request (self , request : PipelineRequest [HTTPRequestType ], * scopes : str , ** kwargs : Any ) -> None :
67+ """Acquire a token from the credential and authorize the request with it.
68+
69+ Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
70+ authorize future requests.
71+
72+ :param ~azure.core.pipeline.PipelineRequest request: the request
73+ :param str scopes: required scopes of authentication
74+ """
75+
76+ await super ().authorize_request (request , self ._current_scope , ** kwargs )
77+ # The None-check for self._token is done in the parent authorize_request
78+ self ._update_headers (request .http_request .headers , cast (AccessToken , self ._token ).token )
0 commit comments