Skip to content

Commit 0899474

Browse files
authored
[Cosmos] Typing part 2 (#33341)
* More typing * Literals * Fix import * Add typing extensions * Body must be Dict * Fixed some samples typing * Multihash pk typehints * Remove ignored option * Review feedback * More review feedback * Fix tests * Fix instance check
1 parent 3319841 commit 0899474

31 files changed

+1225
-1309
lines changed

sdk/cosmos/azure-cosmos/azure/cosmos/_auth_policy.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,47 @@
33
# Licensed under the MIT License. See LICENSE.txt in the project root for
44
# license information.
55
# -------------------------------------------------------------------------
6-
from typing import MutableMapping
6+
from typing import TypeVar, Any, MutableMapping
77

8+
from azure.core.pipeline import PipelineRequest
89
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
9-
from azure.cosmos import http_constants
10+
from azure.core.pipeline.transport import HttpRequest as LegacyHttpRequest
11+
from azure.core.rest import HttpRequest
12+
13+
from .http_constants import HttpHeaders
14+
15+
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
1016

1117

1218
class CosmosBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
1319

1420
@staticmethod
1521
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
1622
"""Updates the Authorization header with the bearer token.
17-
This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works
18-
to properly sign the authorization header for Cosmos' REST API. For more information:
19-
https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header
2023
21-
:param dict headers: The HTTP Request headers
24+
:param MutableMapping[str, str] headers: The HTTP Request headers
2225
:param str token: The OAuth token.
2326
"""
24-
headers[http_constants.HttpHeaders.Authorization] = f"type=aad&ver=1.0&sig={token}"
27+
headers[HttpHeaders.Authorization] = f"type=aad&ver=1.0&sig={token}"
28+
29+
def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
30+
"""Called before the policy sends a request.
31+
32+
The base implementation authorizes the request with a bearer token.
33+
34+
:param ~azure.core.pipeline.PipelineRequest request: the request
35+
"""
36+
super().on_request(request)
37+
self._update_headers(request.http_request.headers, self._token.token)
38+
39+
def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
40+
"""Acquire a token from the credential and authorize the request with it.
41+
42+
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to
43+
authorize future requests.
44+
45+
:param ~azure.core.pipeline.PipelineRequest request: the request
46+
:param str scopes: required scopes of authentication
47+
"""
48+
super().authorize_request(request, *scopes, **kwargs)
49+
self._update_headers(request.http_request.headers, self._token.token)

sdk/cosmos/azure-cosmos/azure/cosmos/_base.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from email.utils import formatdate
2727
import json
2828
import uuid
29+
import re
2930
import binascii
3031
from typing import Dict, Any, List, Mapping, Optional, Sequence, Union, Tuple, TYPE_CHECKING
3132

@@ -61,6 +62,13 @@
6162
'priority_level': 'priorityLevel'
6263
}
6364

65+
# Cosmos resource ID validation regex breakdown:
66+
# ^ Match start of string.
67+
# [^/\#?]{0,255} Match any character that is not /\#? for between 0-255 characters.
68+
# $ End of string
69+
_VALID_COSMOS_RESOURCE = re.compile(r"^[^/\\#?\t\r\n]{0,255}$")
70+
71+
6472
def _get_match_headers(kwargs: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
6573
if_match = kwargs.pop('if_match', None)
6674
if_none_match = kwargs.pop('if_none_match', None)
@@ -684,8 +692,20 @@ def validate_cache_staleness_value(max_integrated_cache_staleness: Any) -> None:
684692
"integer greater than or equal to zero")
685693

686694

695+
def _validate_resource(resource: Mapping[str, Any]) -> None:
696+
id_: Optional[str] = resource.get("id")
697+
if id_:
698+
try:
699+
if _VALID_COSMOS_RESOURCE.match(id_) is None:
700+
raise ValueError("Id contains illegal chars.")
701+
if id_[-1] in [" ", "\n"]:
702+
raise ValueError("Id ends with a space or newline.")
703+
except TypeError as e:
704+
raise TypeError("Id type must be a string.") from e
705+
706+
687707
def _stringify_auto_scale(offer: ThroughputProperties) -> str:
688-
auto_scale_params = None
708+
auto_scale_params: Optional[Dict[str, Union[None, int, Dict[str, Any]]]] = None
689709
max_throughput = offer.auto_scale_max_throughput
690710
increment_percent = offer.auto_scale_increment_percent
691711
auto_scale_params = {"maxThroughput": max_throughput}

sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,28 @@
2323
"""
2424

2525

26-
class _Constants(object):
26+
from typing import Dict
27+
from typing_extensions import Literal
28+
29+
30+
class _Constants:
2731
"""Constants used in the azure-cosmos package"""
2832

29-
UserConsistencyPolicy = "userConsistencyPolicy"
30-
DefaultConsistencyLevel = "defaultConsistencyLevel"
33+
UserConsistencyPolicy: Literal["userConsistencyPolicy"] = "userConsistencyPolicy"
34+
DefaultConsistencyLevel: Literal["defaultConsistencyLevel"] = "defaultConsistencyLevel"
3135

3236
# GlobalDB related constants
33-
WritableLocations = "writableLocations"
34-
ReadableLocations = "readableLocations"
35-
Name = "name"
36-
DatabaseAccountEndpoint = "databaseAccountEndpoint"
37-
DefaultUnavailableLocationExpirationTime = 5 * 60 * 1000
37+
WritableLocations: Literal["writableLocations"] = "writableLocations"
38+
ReadableLocations: Literal["readableLocations"] = "readableLocations"
39+
Name: Literal["name"] = "name"
40+
DatabaseAccountEndpoint: Literal["databaseAccountEndpoint"] = "databaseAccountEndpoint"
41+
DefaultUnavailableLocationExpirationTime: int = 5 * 60 * 1000
3842

3943
# ServiceDocument Resource
40-
EnableMultipleWritableLocations = "enableMultipleWriteLocations"
44+
EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations"
4145

4246
# Error code translations
43-
ERROR_TRANSLATIONS = {
47+
ERROR_TRANSLATIONS: Dict[int, str] = {
4448
400: "BAD_REQUEST - Request being sent is invalid.",
4549
401: "UNAUTHORIZED - The input authorization token can't serve the request.",
4650
403: "FORBIDDEN",

0 commit comments

Comments
 (0)