|
1 |
| -# ------------------------------------ |
2 |
| -# Copyright (c) Microsoft Corporation. |
3 |
| -# Licensed under the MIT License. |
4 |
| -# ------------------------------------ |
| 1 | +# coding=utf-8 |
| 2 | +# -------------------------------------------------------------------------- |
| 3 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 4 | +# Licensed under the MIT License. See License.txt in the project root for license information. |
| 5 | +# -------------------------------------------------------------------------- |
5 | 6 | """Customize generated code here.
|
6 | 7 |
|
7 | 8 | Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
|
8 | 9 | """
|
9 |
| -import base64 |
10 |
| -import hmac |
11 |
| -import hashlib |
12 |
| -import importlib |
13 |
| -from datetime import datetime |
14 |
| -from typing import TYPE_CHECKING, TypeVar, Any, Union |
| 10 | +from typing import List |
15 | 11 |
|
16 |
| -from azure.core.pipeline.policies import SansIOHTTPPolicy |
17 |
| -from azure.core.credentials import AzureNamedKeyCredential, TokenCredential |
18 |
| -from azure.core.pipeline import PipelineResponse, PipelineRequest |
19 |
| -from azure.core.pipeline.transport import HttpResponse |
20 |
| -from azure.core.rest import HttpRequest |
21 |
| - |
22 |
| -from ._client import BatchClient as GenerateBatchClient |
23 |
| -from ._serialization import ( |
24 |
| - Serializer, |
25 |
| - TZ_UTC, |
26 |
| -) |
27 |
| - |
28 |
| -try: |
29 |
| - from urlparse import urlparse, parse_qs |
30 |
| -except ImportError: |
31 |
| - from urllib.parse import urlparse, parse_qs |
32 |
| -__all__ = [ |
33 |
| - "BatchClient", |
34 |
| -] # Add all objects you want publicly available to users at this package level |
35 |
| - |
36 |
| -if TYPE_CHECKING: |
37 |
| - # pylint: disable=unused-import,ungrouped-imports |
38 |
| - from typing import Any, Callable, Dict, Optional, TypeVar, Union |
39 |
| - |
40 |
| - from azure.core.credentials import TokenCredential |
41 |
| - from azure.core.pipeline import PipelineRequest |
42 |
| - |
43 |
| - ClientType = TypeVar("ClientType", bound="BatchClient") |
44 |
| - T = TypeVar("T") |
45 |
| - ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] |
46 |
| - |
47 |
| - |
48 |
| -class BatchSharedKeyAuthPolicy(SansIOHTTPPolicy): |
49 |
| - |
50 |
| - headers_to_sign = [ |
51 |
| - "content-encoding", |
52 |
| - "content-language", |
53 |
| - "content-length", |
54 |
| - "content-md5", |
55 |
| - "content-type", |
56 |
| - "date", |
57 |
| - "if-modified-since", |
58 |
| - "if-match", |
59 |
| - "if-none-match", |
60 |
| - "if-unmodified-since", |
61 |
| - "range", |
62 |
| - ] |
63 |
| - |
64 |
| - def __init__(self, credential: AzureNamedKeyCredential): |
65 |
| - super(BatchSharedKeyAuthPolicy, self).__init__() |
66 |
| - self._account_name = credential.named_key.name |
67 |
| - self._key = credential.named_key.key |
68 |
| - |
69 |
| - def on_request(self, request: PipelineRequest): |
70 |
| - if not request.http_request.headers.get("ocp-date"): |
71 |
| - now = datetime.utcnow() |
72 |
| - now = now.replace(tzinfo=TZ_UTC) |
73 |
| - request.http_request.headers["ocp-date"] = Serializer.serialize_rfc(now) |
74 |
| - url = urlparse(request.http_request.url) |
75 |
| - uri_path = url.path |
76 |
| - |
77 |
| - # method to sign |
78 |
| - string_to_sign = request.http_request.method + "\n" |
79 |
| - |
80 |
| - # get headers to sign |
81 |
| - request_header_dict = {key.lower(): val for key, val in request.http_request.headers.items() if val} |
82 |
| - |
83 |
| - if request.http_request.method not in ["GET", "HEAD"]: |
84 |
| - if "content-length" not in request_header_dict: |
85 |
| - request_header_dict["content-length"] = "0" |
86 |
| - |
87 |
| - request_headers = [str(request_header_dict.get(x, "")) for x in self.headers_to_sign] |
88 |
| - |
89 |
| - string_to_sign += "\n".join(request_headers) + "\n" |
90 |
| - |
91 |
| - # get ocp- header to sign |
92 |
| - ocp_headers = [] |
93 |
| - for name, value in request.http_request.headers.items(): |
94 |
| - if "ocp-" in name and value: |
95 |
| - ocp_headers.append((name.lower(), value)) |
96 |
| - for name, value in sorted(ocp_headers): |
97 |
| - string_to_sign += "{}:{}\n".format(name, value) |
98 |
| - # get account_name and uri path to sign |
99 |
| - string_to_sign += "/{}{}".format(self._account_name, uri_path) |
100 |
| - |
101 |
| - # get query string to sign if it is not table service |
102 |
| - query_to_sign = parse_qs(url.query) |
103 |
| - |
104 |
| - for name in sorted(query_to_sign.keys()): |
105 |
| - value = query_to_sign[name][0] |
106 |
| - if value: |
107 |
| - string_to_sign += "\n{}:{}".format(name, value) |
108 |
| - # sign the request |
109 |
| - auth_string = "SharedKey {}:{}".format(self._account_name, self._sign_string(string_to_sign)) |
110 |
| - |
111 |
| - request.http_request.headers["Authorization"] = auth_string |
112 |
| - |
113 |
| - return super().on_request(request) |
114 |
| - |
115 |
| - def _sign_string(self, string_to_sign): |
116 |
| - |
117 |
| - _key = self._key.encode("utf-8") |
118 |
| - string_to_sign = string_to_sign.encode("utf-8") |
119 |
| - |
120 |
| - try: |
121 |
| - key = base64.b64decode(_key) |
122 |
| - except TypeError: |
123 |
| - raise ValueError("Invalid key value: {}".format(self._key)) |
124 |
| - signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256) |
125 |
| - digest = signed_hmac_sha256.digest() |
126 |
| - |
127 |
| - return base64.b64encode(digest).decode("utf-8") |
128 |
| - |
129 |
| - |
130 |
| -class BatchClient(GenerateBatchClient): |
131 |
| - """BatchClient. |
132 |
| -
|
133 |
| - :param endpoint: HTTP or HTTPS endpoint for the Web PubSub service instance. |
134 |
| - :type endpoint: str |
135 |
| - :param hub: Target hub name, which should start with alphabetic characters and only contain |
136 |
| - alpha-numeric characters or underscore. |
137 |
| - :type hub: str |
138 |
| - :param credentials: Credential needed for the client to connect to Azure. |
139 |
| - :type credentials: ~azure.identity.ClientSecretCredential, ~azure.core.credentials.AzureNamedKeyCredential, |
140 |
| - or ~azure.identity.TokenCredentials |
141 |
| - :keyword api_version: Api Version. The default value is "2021-10-01". Note that overriding this |
142 |
| - default value may result in unsupported behavior. |
143 |
| - :paramtype api_version: str |
144 |
| - """ |
145 |
| - |
146 |
| - def __init__(self, endpoint: str, credential: Union[AzureNamedKeyCredential, TokenCredential], **kwargs): |
147 |
| - super().__init__( |
148 |
| - endpoint=endpoint, |
149 |
| - credential=credential, # type: ignore |
150 |
| - authentication_policy=kwargs.pop( |
151 |
| - "authentication_policy", self._format_shared_key_credential("", credential) |
152 |
| - ), |
153 |
| - **kwargs |
154 |
| - ) |
155 |
| - |
156 |
| - def _format_shared_key_credential(self, account_name, credential): |
157 |
| - if isinstance(credential, AzureNamedKeyCredential): |
158 |
| - return BatchSharedKeyAuthPolicy(credential) |
159 |
| - return None |
| 12 | +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level |
160 | 13 |
|
161 | 14 |
|
162 | 15 | def patch_sdk():
|
|
0 commit comments