Skip to content

Commit 3365c8c

Browse files
committed
Update placement of fallback logic.
1 parent 6afd079 commit 3365c8c

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from typing import TypeVar, Any, MutableMapping, cast
6+
from typing import TypeVar, Any, MutableMapping, cast, Optional
77

88
from azure.core.pipeline import PipelineRequest
99
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
@@ -19,7 +19,7 @@
1919
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
2020
AadDefaultScope = "https://cosmos.azure.com/.default"
2121

22-
def __init__(self, credential, account_scope: str, override_scope: str = None):
22+
def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
2323
self._account_scope = account_scope
2424
self._override_scope = override_scope
2525
self._current_scope = override_scope or account_scope
@@ -41,25 +41,11 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4141
4242
:param ~azure.core.pipeline.PipelineRequest request: the request
4343
"""
44-
self.authorize_request(request)
45-
super().on_request(request)
46-
# The None-check for self._token is done in the parent on_request
47-
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
48-
49-
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
50-
"""Acquire a token from the credential and authorize the request with it.
51-
52-
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
53-
authorize future requests.
54-
55-
:param ~azure.core.pipeline.PipelineRequest request: the request
56-
:param str scopes: required scopes of authentication
57-
"""
5844
tried_fallback = False
5945
while True:
6046
try:
61-
super().authorize_request(request, self._current_scope, **kwargs)
62-
# The None-check for self._token is done in the parent authorize_request
47+
super().on_request(request)
48+
# The None-check for self._token is done in the parent on_request
6349
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
6450
break
6551
except Exception as ex:
@@ -70,7 +56,22 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
7056
self._current_scope != self.AadDefaultScope and
7157
"AADSTS500011" in str(ex)
7258
):
59+
self._scopes = (self.AadDefaultScope,)
7360
self._current_scope = self.AadDefaultScope
7461
tried_fallback = True
7562
continue
7663
raise
64+
65+
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
66+
"""Acquire a token from the credential and authorize the request with it.
67+
68+
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
69+
authorize future requests.
70+
71+
:param ~azure.core.pipeline.PipelineRequest request: the request
72+
:param str scopes: required scopes of authentication
73+
"""
74+
75+
super().authorize_request(request, self._current_scope, **kwargs)
76+
# The None-check for self._token is done in the parent authorize_request
77+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# license information.
55
# -------------------------------------------------------------------------
66

7-
from typing import Any, MutableMapping, TypeVar, cast
7+
from typing import Any, MutableMapping, TypeVar, cast, Optional
88

99
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
1010
from azure.core.pipeline import PipelineRequest
@@ -20,7 +20,7 @@
2020
class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
2121
AadDefaultScope = "https://cosmos.azure.com/.default"
2222

23-
def __init__(self, credential, account_scope: str, override_scope: str):
23+
def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
2424
self._account_scope = account_scope
2525
self._override_scope = override_scope
2626
self._current_scope = override_scope or account_scope
@@ -42,25 +42,11 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4242
:type request: ~azure.core.pipeline.PipelineRequest
4343
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
4444
"""
45-
await self.authorize_request(request)
46-
await super().on_request(request)
47-
# The None-check for self._token is done in the parent on_request
48-
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
49-
50-
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
51-
"""Acquire a token from the credential and authorize the request with it.
52-
53-
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
54-
authorize future requests.
55-
56-
:param ~azure.core.pipeline.PipelineRequest request: the request
57-
:param str scopes: required scopes of authentication
58-
"""
5945
tried_fallback = False
6046
while True:
6147
try:
62-
await super().authorize_request(request, self._current_scope, **kwargs)
63-
# The None-check for self._token is done in the parent authorize_request
48+
await super().on_request(request)
49+
# The None-check for self._token is done in the parent on_request
6450
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
6551
break
6652
except Exception as ex:
@@ -71,7 +57,22 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
7157
self._current_scope != self.AadDefaultScope and
7258
"AADSTS500011" in str(ex)
7359
):
60+
self._scopes = (self.AadDefaultScope,)
7461
self._current_scope = self.AadDefaultScope
7562
tried_fallback = True
7663
continue
7764
raise
65+
66+
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
67+
"""Acquire a token from the credential and authorize the request with it.
68+
69+
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
70+
authorize future requests.
71+
72+
:param ~azure.core.pipeline.PipelineRequest request: the request
73+
:param str scopes: required scopes of authentication
74+
"""
75+
76+
await super().authorize_request(request, self._current_scope, **kwargs)
77+
# The None-check for self._token is done in the parent authorize_request
78+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)

0 commit comments

Comments
 (0)