Skip to content

Commit e4e0904

Browse files
committed
Updates based on code reviews
1 parent ccb3d37 commit e4e0904

File tree

7 files changed

+29
-14
lines changed

7 files changed

+29
-14
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,24 @@
55
# -------------------------------------------------------------------------
66
from typing import TypeVar, Any, MutableMapping, cast, Optional
77

8+
from azure.cosmos import _constants as Constants
89
from azure.core.pipeline import PipelineRequest
910
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
1011
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
1112
from azure.core.rest import HttpRequest
1213
from azure.core.credentials import AccessToken
14+
from azure.core.exceptions import HttpResponseError
1315

1416
from .http_constants import HttpHeaders
1517

1618
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
1719

18-
20+
# NOTE: This class accesses protected members (_scopes, _token) of the parent class
21+
# to implement fallback and scope-switching logic not exposed by the public API.
22+
# Composition was considered, but still required accessing protected members, so inheritance is retained
23+
# for seamless Azure SDK pipeline integration.
1924
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
20-
AadDefaultScope = "https://cosmos.azure.com/.default"
25+
AadDefaultScope = Constants._Constants.AAD_DEFAULT_SCOPE
2126

2227
def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
2328
self._account_scope = account_scope
@@ -48,7 +53,7 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4853
# The None-check for self._token is done in the parent on_request
4954
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
5055
break
51-
except Exception as ex:
56+
except HttpResponseError as ex:
5257
# Only fallback if not using override, not already tried, and error is AADSTS500011
5358
if (
5459
not self._override_scope and

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class _Constants:
5656
CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"
5757
CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False"
5858
AAD_SCOPE_OVERRIDE: str = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE"
59+
AAD_DEFAULT_SCOPE: str = "https://cosmos.azure.com/.default"
5960

6061
# Database Account Retry Policy constants
6162
AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES: str = "AZURE_COSMOS_HEALTH_CHECK_MAX_RETRIES"

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__( # pylint: disable=too-many-statements
118118
auth: CredentialDict,
119119
connection_policy: Optional[ConnectionPolicy] = None,
120120
consistency_level: Optional[str] = None,
121+
audience: Optional[str] = None,
121122
**kwargs: Any
122123
) -> None:
123124
"""
@@ -132,7 +133,8 @@ def __init__( # pylint: disable=too-many-statements
132133
The connection policy for the client.
133134
:param documents.ConsistencyLevel consistency_level:
134135
The default consistency policy for client operations.
135-
136+
:param str audience:
137+
The overridden scope value.
136138
"""
137139
self.client_id = str(uuid.uuid4())
138140
self.url_connection = url_connection
@@ -204,7 +206,7 @@ def __init__( # pylint: disable=too-many-statements
204206

205207
credentials_policy = None
206208
if self.aad_credentials:
207-
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
209+
scope_override = audience or os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
208210
account_scope = base.create_scope_from_url(self.url_connection)
209211
credentials_policy = CosmosBearerTokenCredentialPolicy(
210212
self.aad_credentials,

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@
66

77
from typing import Any, MutableMapping, TypeVar, cast, Optional
88

9+
from azure.cosmos import _constants as Constants
910
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
1011
from azure.core.pipeline import PipelineRequest
1112
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
1213
from azure.core.rest import HttpRequest
1314
from azure.core.credentials import AccessToken
15+
from azure.core.exceptions import HttpResponseError
1416

1517
from ..http_constants import HttpHeaders
1618

1719
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
1820

19-
21+
# NOTE: This class accesses protected members (_scopes, _token) of the parent class
22+
# to implement fallback and scope-switching logic not exposed by the public API.
23+
# Composition was considered, but still required accessing protected members, so inheritance is retained
24+
# for seamless Azure SDK pipeline integration.
2025
class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
21-
AadDefaultScope = "https://cosmos.azure.com/.default"
26+
AadDefaultScope = Constants._Constants.AAD_DEFAULT_SCOPE
2227

2328
def __init__(self, credential, account_scope: str, override_scope: Optional[str] = None):
2429
self._account_scope = account_scope
@@ -49,7 +54,7 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
4954
# The None-check for self._token is done in the parent on_request
5055
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
5156
break
52-
except Exception as ex:
57+
except HttpResponseError as ex:
5358
# Only fallback if not using override, not already tried, and error is AADSTS500011
5459
if (
5560
not self._override_scope and

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__( # pylint: disable=too-many-statements
124124
auth: CredentialDict,
125125
connection_policy: Optional[ConnectionPolicy] = None,
126126
consistency_level: Optional[str] = None,
127+
audience: Optional[str] = None,
127128
**kwargs: Any
128129
) -> None:
129130
"""
@@ -138,7 +139,8 @@ def __init__( # pylint: disable=too-many-statements
138139
The connection policy for the client.
139140
:param documents.ConsistencyLevel consistency_level:
140141
The default consistency policy for client operations.
141-
142+
:param str audience:
143+
The overridden scope value.
142144
"""
143145
self.client_id = str(uuid.uuid4())
144146
self.url_connection = url_connection
@@ -212,7 +214,7 @@ def __init__( # pylint: disable=too-many-statements
212214

213215
credentials_policy = None
214216
if self.aad_credentials:
215-
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
217+
scope_override = audience or os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
216218
account_scope = base.create_scope_from_url(self.url_connection)
217219
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(
218220
self.aad_credentials,

sdk/cosmos/azure-cosmos/tests/test_aad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import azure.cosmos.cosmos_client as cosmos_client
1515
import test_config
1616
from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions
17-
17+
from azure.core.exceptions import HttpResponseError
1818

1919
def _remove_padding(encoded_string):
2020
while encoded_string.endswith("="):
@@ -216,7 +216,7 @@ def __init__(self):
216216
def get_token(self, *scopes, **kwargs):
217217
self.call_count += 1
218218
if self.call_count == 1:
219-
raise Exception("AADSTS500011: Simulated error for fallback")
219+
raise HttpResponseError(message="AADSTS500011: Simulated error for fallback")
220220
return super().get_token(*scopes, **kwargs)
221221

222222
def action(scopes_captured):

sdk/cosmos/azure-cosmos/tests/test_aad_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import test_config
1515
from azure.cosmos import exceptions
1616
from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy
17-
17+
from azure.core.exceptions import HttpResponseError
1818

1919
def _remove_padding(encoded_string):
2020
while encoded_string.endswith("="):
@@ -243,7 +243,7 @@ def __init__(self):
243243
async def get_token(self, *scopes, **kwargs):
244244
self.call_count += 1
245245
if self.call_count == 1:
246-
raise Exception("AADSTS500011: Simulated error for fallback")
246+
raise HttpResponseError(message="AADSTS500011: Simulated error for fallback")
247247
return await super().get_token(*scopes, **kwargs)
248248

249249
async def action(scopes_captured):

0 commit comments

Comments
 (0)