Skip to content

Commit 862a371

Browse files
authored
TokenCredentialCache: Adds a fallback mechanism to AAD scope override. (#42731)
* Update AAD fallback mechanism. * Resolve merge conflicts. * Update changelog. * Update placement of fallback logic. * Code cleaup. * Updates based on code reviews * Code cleanup * Code cleanup
1 parent 3dc804e commit 862a371

File tree

8 files changed

+310
-54
lines changed

8 files changed

+310
-54
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#### Other Changes
2121
* Added session token false progress merge logic. See [42393](https://github.com/Azure/azure-sdk-for-python/pull/42393)
22+
* Added a fallback mechanism to AAD scope override. See [PR 42731](https://github.com/Azure/azure-sdk-for-python/pull/42731).
2223

2324
### 4.14.0b2 (2025-08-12)
2425

sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,32 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from typing import TypeVar, Any, MutableMapping, cast
6+
from typing import TypeVar, Any, MutableMapping, cast, Optional
77

88
from azure.core.pipeline import PipelineRequest
99
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
1010
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
1111
from azure.core.rest import HttpRequest
1212
from azure.core.credentials import AccessToken
13+
from azure.core.exceptions import HttpResponseError
1314

1415
from .http_constants import HttpHeaders
16+
from ._constants import _Constants as Constants
1517

1618
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
1719

18-
20+
# NOTE: This class accesses protected members (_scopes, _token) of the parent class
21+
# to implement fallback and scope-switching logic not exposed by the public API.
22+
# Composition was considered, but still required accessing protected members, so inheritance is retained
23+
# for seamless Azure SDK pipeline integration.
1924
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
25+
AadDefaultScope = Constants.AAD_DEFAULT_SCOPE
26+
27+
def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
28+
self._account_scope = account_scope
29+
self._override_scope = override_scope
30+
self._current_scope = override_scope or account_scope
31+
super().__init__(credential, self._current_scope)
2032

2133
@staticmethod
2234
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@@ -34,9 +46,26 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
3446
3547
:param ~azure.core.pipeline.PipelineRequest request: the request
3648
"""
37-
super().on_request(request)
38-
# The None-check for self._token is done in the parent on_request
39-
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
49+
tried_fallback = False
50+
while True:
51+
try:
52+
super().on_request(request)
53+
# The None-check for self._token is done in the parent on_request
54+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
55+
break
56+
except HttpResponseError as ex:
57+
# Only fallback if not using override, not already tried, and error is AADSTS500011
58+
if (
59+
not self._override_scope and
60+
not tried_fallback and
61+
self._current_scope != self.AadDefaultScope and
62+
"AADSTS500011" in str(ex)
63+
):
64+
self._scopes = (self.AadDefaultScope,)
65+
self._current_scope = self.AadDefaultScope
66+
tried_fallback = True
67+
continue
68+
raise
4069

4170
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
4271
"""Acquire a token from the credential and authorize the request with it.
@@ -47,6 +76,7 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
4776
:param ~azure.core.pipeline.PipelineRequest request: the request
4877
:param str scopes: required scopes of authentication
4978
"""
79+
5080
super().authorize_request(request, *scopes, **kwargs)
5181
# The None-check for self._token is done in the parent authorize_request
5282
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class _Constants:
5656
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
5757
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
5858
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
59+
AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default"
5960

6061
# Database Account Retry Policy constants
6162
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ def __init__( # pylint: disable=too-many-statements
132132
The connection policy for the client.
133133
:param documents.ConsistencyLevel consistency_level:
134134
The default consistency policy for client operations.
135-
136135
"""
137136
self.client_id = str(uuid.uuid4())
138137
self.url_connection = url_connection
@@ -205,11 +204,12 @@ def __init__( # pylint: disable=too-many-statements
205204
credentials_policy = None
206205
if self.aad_credentials:
207206
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
208-
if scope_override:
209-
scope = scope_override
210-
else:
211-
scope = base.create_scope_from_url(self.url_connection)
212-
credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
207+
account_scope = base.create_scope_from_url(self.url_connection)
208+
credentials_policy = CosmosBearerTokenCredentialPolicy(
209+
self.aad_credentials,
210+
account_scope=account_scope,
211+
override_scope=scope_override if scope_override else None
212+
)
213213
self._enable_diagnostics_logging = kwargs.pop("enable_diagnostics_logging", False)
214214
policies = [
215215
HeadersPolicy(**kwargs),

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,32 @@
44
# license information.
55
# -------------------------------------------------------------------------
66

7-
from typing import Any, MutableMapping, TypeVar, cast
7+
from typing import Any, MutableMapping, TypeVar, cast, Optional
88

99
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
1010
from azure.core.pipeline import PipelineRequest
1111
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
1212
from azure.core.rest import HttpRequest
1313
from azure.core.credentials import AccessToken
14+
from azure.core.exceptions import HttpResponseError
1415

1516
from ..http_constants import HttpHeaders
17+
from .._constants import _Constants as Constants
1618

1719
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
1820

19-
21+
# NOTE: This class accesses protected members (_scopes, _token) of the parent class
22+
# to implement fallback and scope-switching logic not exposed by the public API.
23+
# Composition was considered, but still required accessing protected members, so inheritance is retained
24+
# for seamless Azure SDK pipeline integration.
2025
class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
26+
AadDefaultScope = Constants.AAD_DEFAULT_SCOPE
27+
28+
def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
29+
self._account_scope = account_scope
30+
self._override_scope = override_scope
31+
self._current_scope = override_scope or account_scope
32+
super().__init__(credential, self._current_scope)
2133

2234
@staticmethod
2335
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@@ -35,9 +47,26 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
3547
:type request: ~azure.core.pipeline.PipelineRequest
3648
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
3749
"""
38-
await super().on_request(request)
39-
# The None-check for self._token is done in the parent on_request
40-
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
50+
tried_fallback = False
51+
while True:
52+
try:
53+
await super().on_request(request)
54+
# The None-check for self._token is done in the parent on_request
55+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
56+
break
57+
except HttpResponseError as ex:
58+
# Only fallback if not using override, not already tried, and error is AADSTS500011
59+
if (
60+
not self._override_scope and
61+
not tried_fallback and
62+
self._current_scope != self.AadDefaultScope and
63+
"AADSTS500011" in str(ex)
64+
):
65+
self._scopes = (self.AadDefaultScope,)
66+
self._current_scope = self.AadDefaultScope
67+
tried_fallback = True
68+
continue
69+
raise
4170

4271
async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
4372
"""Acquire a token from the credential and authorize the request with it.
@@ -48,6 +77,7 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
4877
:param ~azure.core.pipeline.PipelineRequest request: the request
4978
:param str scopes: required scopes of authentication
5079
"""
80+
5181
await super().authorize_request(request, *scopes, **kwargs)
5282
# The None-check for self._token is done in the parent authorize_request
5383
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def __init__( # pylint: disable=too-many-statements
137137
The connection policy for the client.
138138
:param documents.ConsistencyLevel consistency_level:
139139
The default consistency policy for client operations.
140-
141140
"""
142141
self.client_id = str(uuid.uuid4())
143142
self.url_connection = url_connection
@@ -212,11 +211,12 @@ def __init__( # pylint: disable=too-many-statements
212211
credentials_policy = None
213212
if self.aad_credentials:
214213
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
215-
if scope_override:
216-
scope = scope_override
217-
else:
218-
scope = base.create_scope_from_url(self.url_connection)
219-
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
214+
account_scope = base.create_scope_from_url(self.url_connection)
215+
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(
216+
self.aad_credentials,
217+
account_scope,
218+
scope_override
219+
)
220220
self._enable_diagnostics_logging = kwargs.pop("enable_diagnostics_logging", False)
221221
policies = [
222222
HeadersPolicy(**kwargs),

sdk/cosmos/azure-cosmos/tests/test_aad.py

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import azure.cosmos.cosmos_client as cosmos_client
1515
import test_config
1616
from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions
17-
17+
from azure.core.exceptions import HttpResponseError
1818

1919
def _remove_padding(encoded_string):
2020
while encoded_string.endswith("="):
@@ -34,7 +34,6 @@ def get_test_item(num):
3434

3535

3636
class CosmosEmulatorCredential(object):
37-
3837
def get_token(self, *scopes, **kwargs):
3938
# type: (*str, **Any) -> AccessToken
4039
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
@@ -118,33 +117,126 @@ def test_aad_credentials(self):
118117
assert e.status_code == 403
119118
print("403 error assertion success")
120119

121-
def test_aad_scope_override(self):
122-
override_scope = "https://my.custom.scope/.default"
123-
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
124120

121+
def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs):
125122
scopes_captured = []
126-
original_get_token = CosmosEmulatorCredential.get_token
123+
original_get_token = credential_cls.get_token
127124

128125
def capturing_get_token(self, *scopes, **kwargs):
129126
scopes_captured.extend(scopes)
130127
return original_get_token(self, *scopes, **kwargs)
131128

132-
CosmosEmulatorCredential.get_token = capturing_get_token
133-
129+
credential_cls.get_token = capturing_get_token
134130
try:
131+
result = action(scopes_captured, *args, **kwargs)
132+
finally:
133+
credential_cls.get_token = original_get_token
134+
return scopes_captured, result
135+
136+
def test_override_scope_no_fallback(self):
137+
"""When override scope is provided, only that scope is used and no fallback occurs."""
138+
override_scope = "https://my.custom.scope/.default"
139+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
140+
141+
def action(scopes_captured):
135142
credential = CosmosEmulatorCredential()
136143
client = cosmos_client.CosmosClient(self.host, credential)
137144
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
138145
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
139-
container.create_item(get_test_item(1))
140-
assert override_scope in scopes_captured
146+
container.create_item(get_test_item(10))
147+
return container
148+
149+
scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action)
150+
try:
151+
assert all(scope == override_scope for scope in scopes), f"Expected only override scope(s), got: {scopes}"
141152
finally:
142-
CosmosEmulatorCredential.get_token = original_get_token
143153
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
144154
try:
145-
container.delete_item(item='Item_1', partition_key='pk')
155+
container.delete_item(item='Item_10', partition_key='pk')
146156
except Exception:
147157
pass
148158

159+
def test_override_scope_auth_error_no_fallback(self):
160+
"""When override scope is provided and auth fails, no fallback to other scopes occurs."""
161+
override_scope = "https://my.custom.scope/.default"
162+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
163+
164+
class FailingCredential(CosmosEmulatorCredential):
165+
def get_token(self, *scopes, **kwargs):
166+
raise Exception("Simulated auth error for override scope")
167+
168+
def action(scopes_captured):
169+
with pytest.raises(Exception) as excinfo:
170+
client = cosmos_client.CosmosClient(self.host, FailingCredential())
171+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
172+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
173+
container.create_item(get_test_item(11))
174+
assert "Simulated auth error" in str(excinfo.value)
175+
return None
176+
177+
scopes, _ = self._run_with_scope_capture(FailingCredential, action)
178+
try:
179+
assert scopes == [override_scope], f"Expected only override scope, got: {scopes}"
180+
finally:
181+
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
182+
183+
def test_account_scope_only(self):
184+
"""When account scope is provided, only that scope is used."""
185+
account_scope = "https://localhost/.default"
186+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""
187+
188+
def action(scopes_captured):
189+
credential = CosmosEmulatorCredential()
190+
client = cosmos_client.CosmosClient(self.host, credential)
191+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
192+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
193+
container.create_item(get_test_item(12))
194+
return container
195+
196+
scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action)
197+
try:
198+
# Accept multiple calls, but only the account_scope should be used
199+
assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}"
200+
finally:
201+
try:
202+
container.delete_item(item='Item_12', partition_key='pk')
203+
except Exception:
204+
pass
205+
206+
def test_account_scope_fallback_on_error(self):
207+
"""When account scope is provided and auth fails, fallback to default scope occurs."""
208+
account_scope = "https://localhost/.default"
209+
fallback_scope = "https://cosmos.azure.com/.default"
210+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""
211+
212+
class FallbackCredential(CosmosEmulatorCredential):
213+
def __init__(self):
214+
self.call_count = 0
215+
216+
def get_token(self, *scopes, **kwargs):
217+
self.call_count += 1
218+
if self.call_count == 1:
219+
raise HttpResponseError(message="AADSTS500011: Simulated error for fallback")
220+
return super().get_token(*scopes, **kwargs)
221+
222+
def action(scopes_captured):
223+
credential = FallbackCredential()
224+
client = cosmos_client.CosmosClient(self.host, credential)
225+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
226+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
227+
container.create_item(get_test_item(13))
228+
return container
229+
230+
scopes, container = self._run_with_scope_capture(FallbackCredential, action)
231+
try:
232+
# Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error
233+
assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}"
234+
finally:
235+
try:
236+
container.delete_item(item='Item_13', partition_key='pk')
237+
except Exception:
238+
pass
239+
240+
149241
if __name__ == "__main__":
150242
unittest.main()

0 commit comments

Comments
 (0)