Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/batch/azure-batch/MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ include LICENSE
include azure/batch/py.typed
recursive-include tests *.py
recursive-include samples *.py *.md
include azure/__init__.py
include azure/__init__.py
6 changes: 6 additions & 0 deletions sdk/batch/azure-batch/_meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"commit": "da5f436da0537251f7336b56f1e2df48c634d147",
"repository_url": "https://github.com/Azure/azure-rest-api-specs",
"typespec_src": "specification/batch/Azure.Batch",
"@azure-tools/typespec-python": "0.44.2"
}
359 changes: 359 additions & 0 deletions sdk/batch/azure-batch/apiview-properties.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion sdk/batch/azure-batch/azure/batch/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from ._configuration import BatchClientConfiguration
from ._operations import BatchClientOperationsMixin
from ._serialization import Deserializer, Serializer
from ._utils.serialization import Deserializer, Serializer

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
Expand All @@ -39,6 +39,7 @@ class BatchClient(BatchClientOperationsMixin):
def __init__(self, endpoint: str, credential: "TokenCredential", **kwargs: Any) -> None:
_endpoint = "{endpoint}"
self._config = BatchClientConfiguration(endpoint=endpoint, credential=credential, **kwargs)

kwargs["request_id_header_name"] = "client-request-id"
_policies = kwargs.pop("policies", None)
if _policies is None:
Expand Down
244 changes: 121 additions & 123 deletions sdk/batch/azure-batch/azure/batch/_operations/_operations.py

Large diffs are not rendered by default.

561 changes: 7 additions & 554 deletions sdk/batch/azure-batch/azure/batch/_operations/_patch.py

Large diffs are not rendered by default.

161 changes: 7 additions & 154 deletions sdk/batch/azure-batch/azure/batch/_patch.py
Original file line number Diff line number Diff line change
@@ -1,162 +1,15 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------
"""Customize generated code here.

Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
import base64
import hmac
import hashlib
import importlib
from datetime import datetime
from typing import TYPE_CHECKING, TypeVar, Any, Union
from typing import List

from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.credentials import AzureNamedKeyCredential, TokenCredential
from azure.core.pipeline import PipelineResponse, PipelineRequest
from azure.core.pipeline.transport import HttpResponse
from azure.core.rest import HttpRequest

from ._client import BatchClient as GenerateBatchClient
from ._serialization import (
Serializer,
TZ_UTC,
)

try:
from urlparse import urlparse, parse_qs
except ImportError:
from urllib.parse import urlparse, parse_qs
__all__ = [
"BatchClient",
] # Add all objects you want publicly available to users at this package level

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Callable, Dict, Optional, TypeVar, Union

from azure.core.credentials import TokenCredential
from azure.core.pipeline import PipelineRequest

ClientType = TypeVar("ClientType", bound="BatchClient")
T = TypeVar("T")
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]]


class BatchSharedKeyAuthPolicy(SansIOHTTPPolicy):

headers_to_sign = [
"content-encoding",
"content-language",
"content-length",
"content-md5",
"content-type",
"date",
"if-modified-since",
"if-match",
"if-none-match",
"if-unmodified-since",
"range",
]

def __init__(self, credential: AzureNamedKeyCredential):
super(BatchSharedKeyAuthPolicy, self).__init__()
self._account_name = credential.named_key.name
self._key = credential.named_key.key

def on_request(self, request: PipelineRequest):
if not request.http_request.headers.get("ocp-date"):
now = datetime.utcnow()
now = now.replace(tzinfo=TZ_UTC)
request.http_request.headers["ocp-date"] = Serializer.serialize_rfc(now)
url = urlparse(request.http_request.url)
uri_path = url.path

# method to sign
string_to_sign = request.http_request.method + "\n"

# get headers to sign
request_header_dict = {key.lower(): val for key, val in request.http_request.headers.items() if val}

if request.http_request.method not in ["GET", "HEAD"]:
if "content-length" not in request_header_dict:
request_header_dict["content-length"] = "0"

request_headers = [str(request_header_dict.get(x, "")) for x in self.headers_to_sign]

string_to_sign += "\n".join(request_headers) + "\n"

# get ocp- header to sign
ocp_headers = []
for name, value in request.http_request.headers.items():
if "ocp-" in name and value:
ocp_headers.append((name.lower(), value))
for name, value in sorted(ocp_headers):
string_to_sign += "{}:{}\n".format(name, value)
# get account_name and uri path to sign
string_to_sign += "/{}{}".format(self._account_name, uri_path)

# get query string to sign if it is not table service
query_to_sign = parse_qs(url.query)

for name in sorted(query_to_sign.keys()):
value = query_to_sign[name][0]
if value:
string_to_sign += "\n{}:{}".format(name, value)
# sign the request
auth_string = "SharedKey {}:{}".format(self._account_name, self._sign_string(string_to_sign))

request.http_request.headers["Authorization"] = auth_string

return super().on_request(request)

def _sign_string(self, string_to_sign):

_key = self._key.encode("utf-8")
string_to_sign = string_to_sign.encode("utf-8")

try:
key = base64.b64decode(_key)
except TypeError:
raise ValueError("Invalid key value: {}".format(self._key))
signed_hmac_sha256 = hmac.HMAC(key, string_to_sign, hashlib.sha256)
digest = signed_hmac_sha256.digest()

return base64.b64encode(digest).decode("utf-8")


class BatchClient(GenerateBatchClient):
"""BatchClient.

:param endpoint: HTTP or HTTPS endpoint for the Web PubSub service instance.
:type endpoint: str
:param hub: Target hub name, which should start with alphabetic characters and only contain
alpha-numeric characters or underscore.
:type hub: str
:param credentials: Credential needed for the client to connect to Azure.
:type credentials: ~azure.identity.ClientSecretCredential, ~azure.core.credentials.AzureNamedKeyCredential,
or ~azure.identity.TokenCredentials
:keyword api_version: Api Version. The default value is "2021-10-01". Note that overriding this
default value may result in unsupported behavior.
:paramtype api_version: str
"""

def __init__(self, endpoint: str, credential: Union[AzureNamedKeyCredential, TokenCredential], **kwargs):
super().__init__(
endpoint=endpoint,
credential=credential, # type: ignore
authentication_policy=kwargs.pop(
"authentication_policy", self._format_shared_key_credential("", credential)
),
**kwargs
)

def _format_shared_key_credential(self, account_name, credential):
if isinstance(credential, AzureNamedKeyCredential):
return BatchSharedKeyAuthPolicy(credential)
return None
__all__: List[str] = [] # Add all objects you want publicly available to users at this package level


def patch_sdk():
Expand Down
6 changes: 6 additions & 0 deletions sdk/batch/azure-batch/azure/batch/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) Python Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) Python Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
# pylint: disable=protected-access, broad-except

Expand All @@ -21,18 +22,14 @@
from datetime import datetime, date, time, timedelta, timezone
from json import JSONEncoder
import xml.etree.ElementTree as ET
from collections.abc import MutableMapping
from typing_extensions import Self
import isodate
from azure.core.exceptions import DeserializationError
from azure.core import CaseInsensitiveEnumMeta
from azure.core.pipeline import PipelineResponse
from azure.core.serialization import _Null

if sys.version_info >= (3, 9):
from collections.abc import MutableMapping
else:
from typing import MutableMapping

_LOGGER = logging.getLogger(__name__)

__all__ = ["SdkJSONEncoder", "Model", "rest_field", "rest_discriminator"]
Expand Down Expand Up @@ -347,7 +344,7 @@ def _get_model(module_name: str, model_name: str):
_UNSET = object()


class _MyMutableMapping(MutableMapping[str, typing.Any]): # pylint: disable=unsubscriptable-object
class _MyMutableMapping(MutableMapping[str, typing.Any]):
def __init__(self, data: typing.Dict[str, typing.Any]) -> None:
self._data = data

Expand Down Expand Up @@ -407,13 +404,13 @@ def get(self, key: str, default: typing.Any = None) -> typing.Any:
return default

@typing.overload
def pop(self, key: str) -> typing.Any: ...
def pop(self, key: str) -> typing.Any: ... # pylint: disable=arguments-differ

@typing.overload
def pop(self, key: str, default: _T) -> _T: ...
def pop(self, key: str, default: _T) -> _T: ... # pylint: disable=signature-differs

@typing.overload
def pop(self, key: str, default: typing.Any) -> typing.Any: ...
def pop(self, key: str, default: typing.Any) -> typing.Any: ... # pylint: disable=signature-differs

def pop(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
"""
Expand Down Expand Up @@ -443,7 +440,7 @@ def clear(self) -> None:
"""
self._data.clear()

def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
def update(self, *args: typing.Any, **kwargs: typing.Any) -> None: # pylint: disable=arguments-differ
"""
Updates D from mapping/iterable E and F.
:param any args: Either a mapping object or an iterable of key-value pairs.
Expand All @@ -454,7 +451,7 @@ def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
def setdefault(self, key: str, default: None = None) -> None: ...

@typing.overload
def setdefault(self, key: str, default: typing.Any) -> typing.Any: ...
def setdefault(self, key: str, default: typing.Any) -> typing.Any: ... # pylint: disable=signature-differs

def setdefault(self, key: str, default: typing.Any = _UNSET) -> typing.Any:
"""
Expand Down Expand Up @@ -644,7 +641,7 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
cls._attr_to_rest_field: typing.Dict[str, _RestField] = dict(attr_to_rest_field.items())
cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}")

return super().__new__(cls) # pylint: disable=no-value-for-parameter
return super().__new__(cls)

def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
for base in cls.__bases__:
Expand Down Expand Up @@ -680,7 +677,7 @@ def _deserialize(cls, data, exist_discriminators):
discriminator_value = data.find(xml_name).text # pyright: ignore
else:
discriminator_value = data.get(discriminator._rest_name)
mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore
mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore # pylint: disable=no-member
return mapped_cls._deserialize(data, exist_discriminators)

def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,10 @@
# pylint: disable=line-too-long,useless-suppression,too-many-lines
# coding=utf-8
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) Python Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

# pyright: reportUnnecessaryTypeIgnoreComment=false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
# --------------------------------------------------------------------------

from abc import ABC
from typing import Optional, TYPE_CHECKING
from typing import Generic, Optional, TYPE_CHECKING, TypeVar

from azure.core import MatchConditions

from ._configuration import BatchClientConfiguration

if TYPE_CHECKING:
from azure.core import PipelineClient
from .serialization import Deserializer, Serializer


from ._serialization import Deserializer, Serializer
TClient = TypeVar("TClient")
TConfig = TypeVar("TConfig")


class BatchClientMixinABC(ABC):
class ClientMixinABC(ABC, Generic[TClient, TConfig]):
"""DO NOT use this class. It is for internal typing use only."""

_client: "PipelineClient"
_config: BatchClientConfiguration
_client: TClient
_config: TConfig
_serialize: "Serializer"
_deserialize: "Deserializer"

Expand Down
2 changes: 1 addition & 1 deletion sdk/batch/azure-batch/azure/batch/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

VERSION = "15.0.0b2"
VERSION = "1.0.0b1"
3 changes: 2 additions & 1 deletion sdk/batch/azure-batch/azure/batch/aio/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from azure.core.pipeline import policies
from azure.core.rest import AsyncHttpResponse, HttpRequest

from .._serialization import Deserializer, Serializer
from .._utils.serialization import Deserializer, Serializer
from ._configuration import BatchClientConfiguration
from ._operations import BatchClientOperationsMixin

Expand All @@ -39,6 +39,7 @@ class BatchClient(BatchClientOperationsMixin):
def __init__(self, endpoint: str, credential: "AsyncTokenCredential", **kwargs: Any) -> None:
_endpoint = "{endpoint}"
self._config = BatchClientConfiguration(endpoint=endpoint, credential=credential, **kwargs)

kwargs["request_id_header_name"] = "client-request-id"
_policies = kwargs.pop("policies", None)
if _policies is None:
Expand Down
Loading