Skip to content

Commit 217379a

Browse files
committed
Update AAD fallback mechanism.
1 parent 5eba99f commit 217379a

File tree

6 files changed

+290
-46
lines changed

6 files changed

+290
-46
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818

1919
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
20+
AadDefaultScope = "https://cosmos.azure.com/.default"
21+
22+
def __init__(self, credential, account_scope: str, override_scope: str = None):
23+
self._account_scope = account_scope
24+
self._override_scope = override_scope
25+
self._current_scope = override_scope or account_scope
26+
super().__init__(credential, self._current_scope)
2027

2128
@staticmethod
2229
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@@ -34,6 +41,7 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
3441
3542
:param ~azure.core.pipeline.PipelineRequest request: the request
3643
"""
44+
self.authorize_request(request)
3745
super().on_request(request)
3846
# The None-check for self._token is done in the parent on_request
3947
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
@@ -47,6 +55,22 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
4755
:param ~azure.core.pipeline.PipelineRequest request: the request
4856
:param str scopes: required scopes of authentication
4957
"""
50-
super().authorize_request(request, *scopes, **kwargs)
51-
# The None-check for self._token is done in the parent authorize_request
52-
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
58+
tried_fallback = False
59+
while True:
60+
try:
61+
super().authorize_request(request, self._current_scope, **kwargs)
62+
# The None-check for self._token is done in the parent authorize_request
63+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
64+
break
65+
except Exception as ex:
66+
# Only fallback if not using override, not already tried, and error is AADSTS500011
67+
if (
68+
not self._override_scope and
69+
not tried_fallback and
70+
self._current_scope != self.AadDefaultScope and
71+
"AADSTS500011" in str(ex)
72+
):
73+
self._current_scope = self.AadDefaultScope
74+
tried_fallback = True
75+
continue
76+
raise

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,12 @@ def __init__( # pylint: disable=too-many-statements
202202
credentials_policy = None
203203
if self.aad_credentials:
204204
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
205-
if scope_override:
206-
scope = scope_override
207-
else:
208-
scope = base.create_scope_from_url(self.url_connection)
209-
credentials_policy = CosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
205+
account_scope = base.create_scope_from_url(self.url_connection)
206+
credentials_policy = CosmosBearerTokenCredentialPolicy(
207+
self.aad_credentials,
208+
account_scope=account_scope,
209+
override_scope=scope_override if scope_override else None
210+
)
210211

211212
policies = [
212213
HeadersPolicy(**kwargs),

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818

1919

2020
class AsyncCosmosBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
21+
AadDefaultScope = "https://cosmos.azure.com/.default"
22+
23+
def __init__(self, credential, account_scope: str, override_scope: str):
24+
self._account_scope = account_scope
25+
self._override_scope = override_scope
26+
self._current_scope = override_scope or account_scope
27+
super().__init__(credential, self._current_scope)
2128

2229
@staticmethod
2330
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@@ -35,6 +42,7 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
3542
:type request: ~azure.core.pipeline.PipelineRequest
3643
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
3744
"""
45+
await self.authorize_request(request)
3846
await super().on_request(request)
3947
# The None-check for self._token is done in the parent on_request
4048
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
@@ -48,6 +56,22 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc
4856
:param ~azure.core.pipeline.PipelineRequest request: the request
4957
:param str scopes: required scopes of authentication
5058
"""
51-
await super().authorize_request(request, *scopes, **kwargs)
52-
# The None-check for self._token is done in the parent authorize_request
53-
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
59+
tried_fallback = False
60+
while True:
61+
try:
62+
await super().authorize_request(request, self._current_scope, **kwargs)
63+
# The None-check for self._token is done in the parent authorize_request
64+
self._update_headers(request.http_request.headers, cast(AccessToken, self._token).token)
65+
break
66+
except Exception as ex:
67+
# Only fallback if not using override, not already tried, and error is AADSTS500011
68+
if (
69+
not self._override_scope and
70+
not tried_fallback and
71+
self._current_scope != self.AadDefaultScope and
72+
"AADSTS500011" in str(ex)
73+
):
74+
self._current_scope = self.AadDefaultScope
75+
tried_fallback = True
76+
continue
77+
raise

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,12 @@ def __init__( # pylint: disable=too-many-statements
212212
credentials_policy = None
213213
if self.aad_credentials:
214214
scope_override = os.environ.get(Constants.AAD_SCOPE_OVERRIDE, "")
215-
if scope_override:
216-
scope = scope_override
217-
else:
218-
scope = base.create_scope_from_url(self.url_connection)
219-
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scope)
215+
account_scope = base.create_scope_from_url(self.url_connection)
216+
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(
217+
self.aad_credentials,
218+
account_scope,
219+
scope_override
220+
)
220221

221222
policies = [
222223
HeadersPolicy(**kwargs),

sdk/cosmos/azure-cosmos/tests/test_aad.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def get_test_item(num):
3434

3535

3636
class CosmosEmulatorCredential(object):
37-
3837
def get_token(self, *scopes, **kwargs):
3938
# type: (*str, **Any) -> AccessToken
4039
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
@@ -118,33 +117,126 @@ def test_aad_credentials(self):
118117
assert e.status_code == 403
119118
print("403 error assertion success")
120119

121-
def test_aad_scope_override(self):
122-
override_scope = "https://my.custom.scope/.default"
123-
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
124120

121+
def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs):
125122
scopes_captured = []
126-
original_get_token = CosmosEmulatorCredential.get_token
123+
original_get_token = credential_cls.get_token
127124

128125
def capturing_get_token(self, *scopes, **kwargs):
129126
scopes_captured.extend(scopes)
130127
return original_get_token(self, *scopes, **kwargs)
131128

132-
CosmosEmulatorCredential.get_token = capturing_get_token
133-
129+
credential_cls.get_token = capturing_get_token
134130
try:
131+
result = action(scopes_captured, *args, **kwargs)
132+
finally:
133+
credential_cls.get_token = original_get_token
134+
return scopes_captured, result
135+
136+
def test_override_scope_no_fallback(self):
137+
"""When override scope is provided, only that scope is used and no fallback occurs."""
138+
override_scope = "https://my.custom.scope/.default"
139+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
140+
141+
def action(scopes_captured):
135142
credential = CosmosEmulatorCredential()
136143
client = cosmos_client.CosmosClient(self.host, credential)
137144
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
138145
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
139-
container.create_item(get_test_item(1))
140-
assert override_scope in scopes_captured
146+
container.create_item(get_test_item(10))
147+
return container
148+
149+
scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action)
150+
try:
151+
assert all(scope == override_scope for scope in scopes), f"Expected only override scope(s), got: {scopes}"
152+
finally:
153+
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
154+
try:
155+
container.delete_item(item='Item_10', partition_key='pk')
156+
except Exception:
157+
pass
158+
159+
def test_override_scope_auth_error_no_fallback(self):
160+
"""When override scope is provided and auth fails, no fallback to other scopes occurs."""
161+
override_scope = "https://my.custom.scope/.default"
162+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = override_scope
163+
164+
class FailingCredential(CosmosEmulatorCredential):
165+
def get_token(self, *scopes, **kwargs):
166+
raise Exception("Simulated auth error for override scope")
167+
168+
def action(scopes_captured):
169+
with pytest.raises(Exception) as excinfo:
170+
client = cosmos_client.CosmosClient(self.host, FailingCredential())
171+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
172+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
173+
container.create_item(get_test_item(11))
174+
assert "Simulated auth error" in str(excinfo.value)
175+
return None
176+
177+
scopes, _ = self._run_with_scope_capture(FailingCredential, action)
178+
try:
179+
assert scopes == [override_scope], f"Expected only override scope, got: {scopes}"
141180
finally:
142-
CosmosEmulatorCredential.get_token = original_get_token
143181
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
182+
183+
def test_account_scope_only(self):
184+
"""When account scope is provided, only that scope is used."""
185+
account_scope = "https://localhost/.default"
186+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""
187+
188+
def action(scopes_captured):
189+
credential = CosmosEmulatorCredential()
190+
client = cosmos_client.CosmosClient(self.host, credential)
191+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
192+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
193+
container.create_item(get_test_item(12))
194+
return container
195+
196+
scopes, container = self._run_with_scope_capture(CosmosEmulatorCredential, action)
197+
try:
198+
# Accept multiple calls, but only the account_scope should be used
199+
assert all(scope == account_scope for scope in scopes), f"Expected only account scope, got: {scopes}"
200+
finally:
144201
try:
145-
container.delete_item(item='Item_1', partition_key='pk')
202+
container.delete_item(item='Item_12', partition_key='pk')
146203
except Exception:
147204
pass
148205

206+
def test_account_scope_fallback_on_error(self):
207+
"""When account scope is provided and auth fails, fallback to default scope occurs."""
208+
account_scope = "https://localhost/.default"
209+
fallback_scope = "https://cosmos.azure.com/.default"
210+
os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"] = ""
211+
212+
class FallbackCredential(CosmosEmulatorCredential):
213+
def __init__(self):
214+
self.call_count = 0
215+
216+
def get_token(self, *scopes, **kwargs):
217+
self.call_count += 1
218+
if self.call_count == 1:
219+
raise Exception("AADSTS500011: Simulated error for fallback")
220+
return super().get_token(*scopes, **kwargs)
221+
222+
def action(scopes_captured):
223+
credential = FallbackCredential()
224+
client = cosmos_client.CosmosClient(self.host, credential)
225+
db = client.get_database_client(self.configs.TEST_DATABASE_ID)
226+
container = db.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
227+
container.create_item(get_test_item(13))
228+
return container
229+
230+
scopes, container = self._run_with_scope_capture(FallbackCredential, action)
231+
try:
232+
# Accept multiple calls, but the first should be account_scope, and fallback_scope should appear after error
233+
assert account_scope in scopes and fallback_scope in scopes, f"Expected fallback to default scope, got: {scopes}"
234+
finally:
235+
try:
236+
container.delete_item(item='Item_13', partition_key='pk')
237+
except Exception:
238+
pass
239+
240+
149241
if __name__ == "__main__":
150242
unittest.main()

0 commit comments

Comments
 (0)