|
1 | | -# ------------------------------------ |
2 | | -# Copyright (c) Microsoft Corporation. |
3 | | -# Licensed under the MIT License. |
4 | | -# ------------------------------------ |
| 1 | +# coding=utf-8 |
| 2 | +# -------------------------------------------------------------------------- |
| 3 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 4 | +# Licensed under the MIT License. See License.txt in the project root for license information. |
| 5 | +# -------------------------------------------------------------------------- |
5 | 6 | """Customize generated code here. |
6 | 7 |
|
7 | 8 | Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize |
8 | 9 | """ |
9 | | -import base64 |
10 | | -import hmac |
11 | | -import hashlib |
12 | | -import importlib |
13 | | -from datetime import datetime |
14 | | -from typing import TYPE_CHECKING, TypeVar, Any, Union |
| 10 | +from typing import List |
15 | 11 |
|
16 | | -from azure.core.pipeline.policies import SansIOHTTPPolicy |
17 | | -from azure.core.credentials import AzureNamedKeyCredential, TokenCredential |
18 | | -from azure.core.pipeline import PipelineResponse, PipelineRequest |
19 | | -from azure.core.pipeline.transport import HttpResponse |
20 | | -from azure.core.rest import HttpRequest |
21 | | - |
22 | | -from ._client import BatchClient as GenerateBatchClient |
23 | | -from ._serialization import ( |
24 | | - Serializer, |
25 | | - TZ_UTC, |
26 | | -) |
27 | | - |
28 | | -try: |
29 | | - from urlparse import urlparse, parse_qs |
30 | | -except ImportError: |
31 | | - from urllib.parse import urlparse, parse_qs |
32 | | -__all__ = [ |
33 | | - "BatchClient", |
34 | | -] # Add all objects you want publicly available to users at this package level |
35 | | - |
36 | | -if TYPE_CHECKING: |
37 | | - # pylint: disable=unused-import,ungrouped-imports |
38 | | - from typing import Any, Callable, Dict, Optional, TypeVar, Union |
39 | | - |
40 | | - from azure.core.credentials import TokenCredential |
41 | | - from azure.core.pipeline import PipelineRequest |
42 | | - |
43 | | - ClientType = TypeVar("ClientType", bound="BatchClient") |
44 | | - T = TypeVar("T") |
45 | | - ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] |
46 | | - |
47 | | - |
48 | | -class BatchSharedKeyAuthPolicy(SansIOHTTPPolicy): |
49 | | - |
50 | | - headers_to_sign = [ |
51 | | - "content-encoding", |
52 | | - "content-language", |
53 | | - "content-length", |
54 | | - "content-md5", |
55 | | - "content-type", |
56 | | - "date", |
57 | | - "if-modified-since", |
58 | | - "if-match", |
59 | | - "if-none-match", |
60 | | - "if-unmodified-since", |
61 | | - "range", |
62 | | - ] |
63 | | - |
64 | | - def __init__(self, credential: AzureNamedKeyCredential): |
65 | | - super(BatchSharedKeyAuthPolicy, self).__init__() |
66 | | - self._account_name = credential.named_key.name |
67 | | - self._key = credential.named_key.key |
68 | | - |
69 | | - def on_request(self, request: PipelineRequest): |
70 | | - if not request.http_request.headers.get("ocp-date"): |
71 | | - now = datetime.utcnow() |
72 | | - now = now.replace(tzinfo=TZ_UTC) |
73 | | - request.http_request.headers["ocp-date"] = Serializer.serialize_rfc(now) |
74 | | - url = urlparse(request.http_request.url) |
75 | | - uri_path = url.path |
76 | | - |
77 | | - # method to sign |
78 | | - string_to_sign = request.http_request.method + "\n" |
79 | | - |
80 | | - # get headers to sign |
81 | | - request_header_dict = {key.lower(): val for key, val in request.http_request.headers.items() if val} |
82 | | - |
83 | | - if request.http_request.method not in ["GET", "HEAD"]: |
84 | | - if "content-length" not in request_header_dict: |
85 | | - request_header_dict["content-length"] = "0" |
86 | | - |
87 | | - request_headers = [str(request_header_dict.get(x, "")) for x in self.headers_to_sign] |
88 | | - |
89 | | - string_to_sign += "\n".join(request_headers) + "\n" |
90 | | - |
91 | | - # get ocp- header to sign |
92 | | - ocp_headers = [] |
93 | | - for name, value in request.http_request.headers.items(): |
94 | | - if "ocp-" in name and value: |
95 | | - ocp_headers.append((name.lower(), value)) |
96 | | - for name, value in sorted(ocp_headers): |
97 | | - string_to_sign += "{}:{}\n".format(name, value) |
98 | | - # get account_name and uri path to sign |
99 | | - string_to_sign += "/{}{}".format(self._account_name, uri_path) |
100 | | - |
101 | | - # get query string to sign if it is not table service |
102 | | - query_to_sign = parse_qs(url.query) |
103 | | - |
104 | | - for name in sorted(query_to_sign.keys()): |
105 | | - value = query_to_sign[name][0] |
106 | | - if value: |
107 | | - string_to_sign += "\n{}:{}".format(name, value) |
108 | | - # sign the request |
109 | | - auth_string = "SharedKey {}:{}".format(self._account_name, self._sign_string(string_to_sign)) |
110 | | - |
111 | | - request.http_request.headers["Authorization"] = auth_string |
112 | | - |
113 | | - return super().on_request(request) |
114 | | - |
115 | | - def _sign_string(self, string_to_sign): |
116 | | - |
117 | | - _key = self._key.encode("utf-8") |
118 | | - string_to_sign = string_to_sign.encode("utf-8") |
119 | | - |
120 | | - try: |
121 | | - key = base64.b64decode(_key) |
122 | | - except TypeError: |
123 | | - raise ValueError("Invalid key value: {}".format(self._key)) |
124 | | - signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) |
125 | | - digest = signed_hmac_sha256.digest() |
126 | | - |
127 | | - return base64.b64encode(digest).decode("utf-8") |
128 | | - |
129 | | - |
130 | | -class BatchClient(GenerateBatchClient): |
131 | | - """BatchClient. |
132 | | -
|
133 | | - :param endpoint: HTTP or HTTPS endpoint for the Web PubSub service instance. |
134 | | - :type endpoint: str |
135 | | - :param hub: Target hub name, which should start with alphabetic characters and only contain |
136 | | - alpha-numeric characters or underscore. |
137 | | - :type hub: str |
138 | | - :param credentials: Credential needed for the client to connect to Azure. |
139 | | - :type credentials: ~azure.identity.ClientSecretCredential, ~azure.core.credentials.AzureNamedKeyCredential, |
140 | | - or ~azure.identity.TokenCredentials |
141 | | - :keyword api_version: Api Version. The default value is "2021-10-01". Note that overriding this |
142 | | - default value may result in unsupported behavior. |
143 | | - :paramtype api_version: str |
144 | | - """ |
145 | | - |
146 | | - def __init__(self, endpoint: str, credential: Union[AzureNamedKeyCredential, TokenCredential], **kwargs): |
147 | | - super().__init__( |
148 | | - endpoint=endpoint, |
149 | | - credential=credential, # type: ignore |
150 | | - authentication_policy=kwargs.pop( |
151 | | - "authentication_policy", self._format_shared_key_credential("", credential) |
152 | | - ), |
153 | | - **kwargs |
154 | | - ) |
155 | | - |
156 | | - def _format_shared_key_credential(self, account_name, credential): |
157 | | - if isinstance(credential, AzureNamedKeyCredential): |
158 | | - return BatchSharedKeyAuthPolicy(credential) |
159 | | - return None |
| 12 | +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level |
160 | 13 |
|
161 | 14 |
|
162 | 15 | def patch_sdk(): |
|
0 commit comments