Skip to content

Commit c17568f

Browse files
authored
[Cosmos] AAD authentication async client (Azure#23717)
* working authentication to get database account * working aad authentication for sync client with sample * readme and changelog * pylint and better comments on sample * working async aad * Delete access_cosmos_with_aad.py snuck its way into the async PR * Update _auth_policies.py * small changes * Update _cosmos_client_connection.py * removing changes made in sync * Update _auth_policy_async.py * Update _auth_policy_async.py * Update _auth_policy_async.py * added licenses to samples
1 parent 93a8dda commit c17568f

23 files changed

+392
-22
lines changed

sdk/cosmos/azure-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
### 4.3.0b4 (Unreleased)
44

55
#### Features Added
6+
- Added support for AAD authentication for the async client
67
- Added support for AAD authentication for the sync client
78

89
### 4.3.0b3 (2022-03-10)

sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@
6161
from ._auth_policy import CosmosBearerTokenCredentialPolicy
6262

6363
ClassType = TypeVar("ClassType")
64-
65-
6664
# pylint: disable=protected-access
6765

6866

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
6+
import asyncio
7+
import time
8+
9+
from typing import Any, Awaitable, Optional, Dict, Union
10+
from azure.core.pipeline.policies import AsyncHTTPPolicy
11+
from azure.core.credentials import AccessToken
12+
from azure.core.pipeline import PipelineRequest, PipelineResponse
13+
from azure.cosmos import http_constants
14+
15+
16+
async def await_result(func, *args, **kwargs):
17+
"""If func returns an awaitable, await it."""
18+
result = func(*args, **kwargs)
19+
if hasattr(result, "__await__"):
20+
# type ignore on await: https://github.com/python/mypy/issues/7587
21+
return await result # type: ignore
22+
return result
23+
24+
25+
class _AsyncCosmosBearerTokenCredentialPolicyBase(object):
26+
"""Base class for a Bearer Token Credential Policy.
27+
28+
:param credential: The credential.
29+
:type credential: ~azure.core.credentials.TokenCredential
30+
:param str scopes: Lets you specify the type of access needed.
31+
"""
32+
33+
def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument
34+
# type: (TokenCredential, *str, **Any) -> None
35+
super(_AsyncCosmosBearerTokenCredentialPolicyBase, self).__init__()
36+
self._scopes = scopes
37+
self._credential = credential
38+
self._token = None # type: Optional[AccessToken]
39+
self._lock = asyncio.Lock()
40+
41+
@staticmethod
42+
def _enforce_https(request):
43+
# type: (PipelineRequest) -> None
44+
45+
# move 'enforce_https' from options to context so it persists
46+
# across retries but isn't passed to a transport implementation
47+
option = request.context.options.pop("enforce_https", None)
48+
49+
# True is the default setting; we needn't preserve an explicit opt in to the default behavior
50+
if option is False:
51+
request.context["enforce_https"] = option
52+
53+
enforce_https = request.context.get("enforce_https", True)
54+
if enforce_https and not request.http_request.url.lower().startswith("https"):
55+
raise ValueError(
56+
"Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
57+
)
58+
59+
@staticmethod
60+
def _update_headers(headers, token):
61+
# type: (Dict[str, str], str) -> None
62+
"""Updates the Authorization header with the cosmos signature and bearer token.
63+
This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works
64+
to properly sign the authorization header for Cosmos' REST API. For more information:
65+
https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header
66+
67+
:param dict headers: The HTTP Request headers
68+
:param str token: The OAuth token.
69+
"""
70+
headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token)
71+
72+
@property
73+
def _need_new_token(self) -> bool:
74+
return not self._token or self._token.expires_on - time.time() < 300
75+
76+
77+
class AsyncCosmosBearerTokenCredentialPolicy(_AsyncCosmosBearerTokenCredentialPolicyBase, AsyncHTTPPolicy):
78+
"""Adds a bearer token Authorization header to requests.
79+
80+
:param credential: The credential.
81+
:type credential: ~azure.core.TokenCredential
82+
:param str scopes: Lets you specify the type of access needed.
83+
:raises ValueError: If https_enforce does not match with endpoint being used.
84+
"""
85+
86+
async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method
87+
"""Adds a bearer token Authorization header to request and sends request to next policy.
88+
89+
:param request: The pipeline request object to be modified.
90+
:type request: ~azure.core.pipeline.PipelineRequest
91+
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
92+
"""
93+
self._enforce_https(request) # pylint:disable=protected-access
94+
95+
if self._token is None or self._need_new_token:
96+
async with self._lock:
97+
# double check because another coroutine may have acquired a token while we waited to acquire the lock
98+
if self._token is None or self._need_new_token:
99+
self._token = await self._credential.get_token(*self._scopes)
100+
self._update_headers(request.http_request.headers, self._token.token)
101+
102+
async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None:
103+
"""Acquire a token from the credential and authorize the request with it.
104+
105+
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
106+
authorize future requests.
107+
108+
:param ~azure.core.pipeline.PipelineRequest request: the request
109+
:param str scopes: required scopes of authentication
110+
"""
111+
async with self._lock:
112+
self._token = await self._credential.get_token(*scopes, **kwargs)
113+
self._update_headers(request.http_request.headers, self._token.token)
114+
115+
async def send(self, request: "PipelineRequest") -> "PipelineResponse":
116+
"""Authorize request with a bearer token and send it to the next policy
117+
118+
:param request: The pipeline request object
119+
:type request: ~azure.core.pipeline.PipelineRequest
120+
"""
121+
await await_result(self.on_request, request)
122+
try:
123+
response = await self.next.send(request)
124+
await await_result(self.on_response, request, response)
125+
except Exception: # pylint:disable=broad-except
126+
handled = await await_result(self.on_exception, request)
127+
if not handled:
128+
raise
129+
else:
130+
if response.http_response.status_code == 401:
131+
self._token = None # any cached token is invalid
132+
if "WWW-Authenticate" in response.http_response.headers:
133+
request_authorized = await self.on_challenge(request, response)
134+
if request_authorized:
135+
try:
136+
response = await self.next.send(request)
137+
await await_result(self.on_response, request, response)
138+
except Exception: # pylint:disable=broad-except
139+
handled = await await_result(self.on_exception, request)
140+
if not handled:
141+
raise
142+
143+
return response
144+
145+
async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
146+
"""Authorize request according to an authentication challenge
147+
148+
This method is called when the resource provider responds 401 with a WWW-Authenticate header.
149+
150+
:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
151+
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
152+
:returns: a bool indicating whether the policy should send the request
153+
"""
154+
# pylint:disable=unused-argument,no-self-use
155+
return False
156+
157+
def on_response(self, request: PipelineRequest, response: PipelineResponse) -> Union[None, Awaitable[None]]:
158+
"""Executed after the request comes back from the next policy.
159+
160+
:param request: Request to be modified after returning from the policy.
161+
:type request: ~azure.core.pipeline.PipelineRequest
162+
:param response: Pipeline response object
163+
:type response: ~azure.core.pipeline.PipelineResponse
164+
"""
165+
166+
def on_exception(self, request: PipelineRequest) -> None:
167+
"""Executed when an exception is raised while executing the next policy.
168+
169+
This method is executed inside the exception handler.
170+
171+
:param request: The Pipeline request object
172+
:type request: ~azure.core.pipeline.PipelineRequest
173+
"""
174+
# pylint: disable=no-self-use,unused-argument
175+
return

sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@
5858
from .. import _session
5959
from .. import _utils
6060
from ..partition_key import _Undefined, _Empty
61+
from ._auth_policy_async import AsyncCosmosBearerTokenCredentialPolicy
6162

6263
ClassType = TypeVar("ClassType")
6364
# pylint: disable=protected-access
6465

66+
6567
class CosmosClientConnection(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes
6668
"""Represents a document client.
6769
@@ -113,9 +115,11 @@ def __init__(
113115

114116
self.master_key = None
115117
self.resource_tokens = None
118+
self.aad_credentials = None
116119
if auth is not None:
117120
self.master_key = auth.get("masterKey")
118121
self.resource_tokens = auth.get("resourceTokens")
122+
self.aad_credentials = auth.get("clientSecretCredential")
119123

120124
if auth.get("permissionFeed"):
121125
self.resource_tokens = {}
@@ -176,12 +180,18 @@ def __init__(
176180

177181
self._user_agent = _utils.get_user_agent_async()
178182

183+
credentials_policy = None
184+
if self.aad_credentials:
185+
scopes = base.create_scope_from_url(self.url_connection)
186+
credentials_policy = AsyncCosmosBearerTokenCredentialPolicy(self.aad_credentials, scopes)
187+
179188
policies = [
180189
HeadersPolicy(**kwargs),
181190
ProxyPolicy(proxies=proxies),
182191
UserAgentPolicy(base_user_agent=self._user_agent, **kwargs),
183192
ContentDecodePolicy(),
184193
retry_policy,
194+
credentials_policy,
185195
CustomHookPolicy(**kwargs),
186196
NetworkTraceLoggingPolicy(**kwargs),
187197
DistributedTracingPolicy(**kwargs),

sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
16
from azure.cosmos import CosmosClient
27
import azure.cosmos.exceptions as exceptions
38
from azure.cosmos.partition_key import PartitionKey
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
6+
from azure.cosmos.aio import CosmosClient
7+
import azure.cosmos.exceptions as exceptions
8+
from azure.cosmos.partition_key import PartitionKey
9+
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential
10+
import config
11+
import asyncio
12+
13+
# ----------------------------------------------------------------------------------------------------------
14+
# Prerequistes -
15+
#
16+
# 1. An Azure Cosmos account -
17+
# https://docs.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account
18+
#
19+
# 2. Microsoft Azure Cosmos
20+
# pip install azure-cosmos>=4.3.0b4
21+
# ----------------------------------------------------------------------------------------------------------
22+
# Sample - demonstrates how to authenticate and use your database account using AAD credentials
23+
# Read more about operations allowed for this authorization method: https://aka.ms/cosmos-native-rbac
24+
# ----------------------------------------------------------------------------------------------------------
25+
# Note:
26+
# This sample creates a Container to your database account.
27+
# Each time a Container is created the account will be billed for 1 hour of usage based on
28+
# the provisioned throughput (RU/s) of that account.
29+
# ----------------------------------------------------------------------------------------------------------
30+
# <configureConnectivity>
31+
HOST = config.settings["host"]
32+
MASTER_KEY = config.settings["master_key"]
33+
34+
TENANT_ID = config.settings["tenant_id"]
35+
CLIENT_ID = config.settings["client_id"]
36+
CLIENT_SECRET = config.settings["client_secret"]
37+
38+
DATABASE_ID = config.settings["database_id"]
39+
CONTAINER_ID = config.settings["container_id"]
40+
PARTITION_KEY = PartitionKey(path="/id")
41+
42+
43+
def get_test_item(num):
44+
test_item = {
45+
'id': 'Item_' + str(num),
46+
'test_object': True,
47+
'lastName': 'Smith'
48+
}
49+
return test_item
50+
51+
52+
async def create_sample_resources():
53+
print("creating sample resources")
54+
async with CosmosClient(HOST, MASTER_KEY) as client:
55+
db = await client.create_database(DATABASE_ID)
56+
await db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)
57+
58+
59+
async def delete_sample_resources():
60+
print("deleting sample resources")
61+
async with CosmosClient(HOST, MASTER_KEY) as client:
62+
await client.delete_database(DATABASE_ID)
63+
64+
65+
async def run_sample():
66+
# Since Azure Cosmos DB data plane SDK does not cover management operations, we have to create our resources
67+
# with a master key authenticated client for this sample.
68+
await create_sample_resources()
69+
70+
# With this done, you can use your AAD service principal id and secret to create your ClientSecretCredential.
71+
# The async ClientSecretCredentials, like the async client, also have a context manager,
72+
# and as such should be used with the `async with` keywords.
73+
async with ClientSecretCredential(
74+
tenant_id=TENANT_ID,
75+
client_id=CLIENT_ID,
76+
client_secret=CLIENT_SECRET) as aad_credentials:
77+
78+
# Use your credentials to authenticate your client.
79+
async with CosmosClient(HOST, aad_credentials) as aad_client:
80+
print("Showed ClientSecretCredential, now showing DefaultAzureCredential")
81+
82+
# You can also utilize DefaultAzureCredential rather than directly passing in the id's and secrets.
83+
# This is the recommended method of authentication, and uses environment variables rather than in-code strings.
84+
async with DefaultAzureCredential() as aad_credentials:
85+
86+
# Use your credentials to authenticate your client.
87+
async with CosmosClient(HOST, aad_credentials) as aad_client:
88+
89+
# Do any R/W data operations with your authorized AAD client.
90+
db = aad_client.get_database_client(DATABASE_ID)
91+
container = db.get_container_client(CONTAINER_ID)
92+
93+
print("Container info: " + str(container.read()))
94+
await container.create_item(get_test_item(879))
95+
print("Point read result: " + str(container.read_item(item='Item_0', partition_key='Item_0')))
96+
query_results = [item async for item in
97+
container.query_items(query='select * from c', partition_key='Item_0')]
98+
assert len(query_results) == 1
99+
print("Query result: " + str(query_results[0]))
100+
await container.delete_item(item='Item_0', partition_key='Item_0')
101+
102+
# Attempting to do management operations will return a 403 Forbidden exception.
103+
try:
104+
await aad_client.delete_database(DATABASE_ID)
105+
except exceptions.CosmosHttpResponseError as e:
106+
assert e.status_code == 403
107+
print("403 error assertion success")
108+
109+
# To clean up the sample, we use a master key client again to get access to deleting containers/ databases.
110+
await delete_sample_resources()
111+
print("end of sample")
112+
113+
114+
if __name__ == "__main__":
115+
loop = asyncio.get_event_loop()
116+
loop.run_until_complete(run_sample())

sdk/cosmos/azure-cosmos/samples/access_cosmos_with_resource_token.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
16
import azure.cosmos.cosmos_client as cosmos_client
27
import azure.cosmos.exceptions as exceptions
38
from azure.cosmos.partition_key import PartitionKey

sdk/cosmos/azure-cosmos/samples/access_cosmos_with_resource_token_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
16
import azure.cosmos.aio.cosmos_client as cosmos_client
27
import azure.cosmos.exceptions as exceptions
38
from azure.cosmos.partition_key import PartitionKey

sdk/cosmos/azure-cosmos/samples/change_feed_management.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
16
import azure.cosmos.documents as documents
27
import azure.cosmos.cosmos_client as cosmos_client
38
import azure.cosmos.exceptions as exceptions

sdk/cosmos/azure-cosmos/samples/change_feed_management_async.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See LICENSE.txt in the project root for
4+
# license information.
5+
# -------------------------------------------------------------------------
16
import azure.cosmos.aio.cosmos_client as cosmos_client
27
import azure.cosmos.exceptions as exceptions
38
import azure.cosmos.documents as documents

0 commit comments

Comments
 (0)