diff --git a/sdk/ai/azure-ai-inference/_meta.json b/sdk/ai/azure-ai-inference/_meta.json new file mode 100644 index 000000000000..79055fc5cec0 --- /dev/null +++ b/sdk/ai/azure-ai-inference/_meta.json @@ -0,0 +1,6 @@ +{ + "commit": "b1ad4b2d6b802834ce695f4b21da2af587f53fba", + "repository_url": "https://github.com/Azure/azure-rest-api-specs", + "typespec_src": "specification/ai/ModelClient", + "@azure-tools/typespec-python": "0.36.1" +} \ No newline at end of file diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py index ff62b276a309..b7537d16cab3 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py @@ -5,24 +5,32 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._patch import ChatCompletionsClient -from ._patch import EmbeddingsClient -from ._patch import ImageEmbeddingsClient +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._client import ChatCompletionsClient # type: ignore +from ._client import EmbeddingsClient # type: ignore +from ._client import ImageEmbeddingsClient # type: ignore from ._version import VERSION __version__ = VERSION - -from ._patch import load_client +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] from ._patch import patch_sdk as _patch_sdk __all__ = [ - "load_client", "ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient", ] - +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py index 25f4b3746e76..f7610fa42095 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py @@ -28,19 +28,18 @@ from ._serialization import Deserializer, Serializer if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials import TokenCredential -class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): """ChatCompletionsClient. :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + ~azure.core.credentials.TokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported behavior. @@ -110,15 +109,15 @@ def __exit__(self, *exc_details: Any) -> None: self._client.__exit__(*exc_details) -class EmbeddingsClient(EmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class EmbeddingsClient(EmbeddingsClientOperationsMixin): """EmbeddingsClient. :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + ~azure.core.credentials.TokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported behavior. @@ -188,15 +187,15 @@ def __exit__(self, *exc_details: Any) -> None: self._client.__exit__(*exc_details) -class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): """ImageEmbeddingsClient. :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + ~azure.core.credentials.TokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported behavior. diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py index 9393659fb910..fb44d9dfdbe2 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py @@ -14,11 +14,10 @@ from ._version import VERSION if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials import TokenCredential -class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ChatCompletionsClient. Note that all parameters used to create this instance are saved as instance @@ -26,10 +25,10 @@ class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-a :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + ~azure.core.credentials.TokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported behavior. @@ -55,8 +54,6 @@ def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCr def _infer_policy(self, **kwargs): if isinstance(self.credential, AzureKeyCredential): return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) - if isinstance(self.credential, AzureKeyCredential): - return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) if hasattr(self.credential, "get_token"): return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) raise TypeError(f"Unsupported credential: {self.credential}") @@ -75,7 +72,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for EmbeddingsClient. Note that all parameters used to create this instance are saved as instance @@ -83,10 +80,10 @@ class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attrib :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + ~azure.core.credentials.TokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported behavior. @@ -112,8 +109,6 @@ def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCr def _infer_policy(self, **kwargs): if isinstance(self.credential, AzureKeyCredential): return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) - if isinstance(self.credential, AzureKeyCredential): - return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) if hasattr(self.credential, "get_token"): return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) raise TypeError(f"Unsupported credential: {self.credential}") @@ -132,7 +127,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ImageEmbeddingsClient. Note that all parameters used to create this instance are saved as instance @@ -140,10 +135,10 @@ class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-a :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.TokenCredential + ~azure.core.credentials.TokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported behavior. @@ -169,8 +164,6 @@ def __init__(self, endpoint: str, credential: Union[AzureKeyCredential, "TokenCr def _infer_policy(self, **kwargs): if isinstance(self.credential, AzureKeyCredential): return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) - if isinstance(self.credential, AzureKeyCredential): - return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) if hasattr(self.credential, "get_token"): return policies.BearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) raise TypeError(f"Unsupported credential: {self.credential}") diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py index c4b1008c1e85..e6a2730f9276 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py @@ -1,10 +1,11 @@ +# pylint: disable=too-many-lines # coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except +# pylint: disable=protected-access, broad-except import copy import calendar @@ -19,6 +20,7 @@ import email.utils from datetime import datetime, date, time, timedelta, timezone from json import JSONEncoder +import xml.etree.ElementTree as ET from typing_extensions import Self import isodate from azure.core.exceptions import DeserializationError @@ -123,7 +125,7 @@ def _serialize_datetime(o, format: typing.Optional[str] = None): def _is_readonly(p): try: - return p._visibility == ["read"] # pylint: disable=protected-access + return p._visibility == ["read"] except AttributeError: return False @@ -286,6 +288,12 @@ def _deserialize_decimal(attr): return decimal.Decimal(str(attr)) +def _deserialize_int_as_str(attr): + if isinstance(attr, int): + return attr + return int(attr) + + _DESERIALIZE_MAPPING = { datetime: _deserialize_datetime, date: _deserialize_date, @@ -307,9 +315,11 @@ def _deserialize_decimal(attr): def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): + if annotation is int and rf and rf._format == "str": + return _deserialize_int_as_str if rf and rf._format: return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) - return _DESERIALIZE_MAPPING.get(annotation) + return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore def _get_type_alias_type(module_name: str, alias_name: str): @@ -441,6 +451,10 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m return float(o) if isinstance(o, enum.Enum): return o.value + if isinstance(o, int): + if format == "str": + return str(o) + return o try: # First try datetime.datetime return _serialize_datetime(o, format) @@ -471,6 +485,8 @@ def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typin return value if rf._is_model: return _deserialize(rf._type, value) + if isinstance(value, ET.Element): + value = _deserialize(rf._type, value) return _serialize(value, rf._format) @@ -489,10 +505,58 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: for rest_field in self._attr_to_rest_field.values() if rest_field._default is not _UNSET } - if args: - dict_to_pass.update( - {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()} - ) + if args: # pylint: disable=too-many-nested-blocks + if isinstance(args[0], ET.Element): + existed_attr_keys = [] + model_meta = getattr(self, "_xml", {}) + + for rf in self._attr_to_rest_field.values(): + prop_meta = getattr(rf, "_xml", {}) + xml_name = prop_meta.get("name", rf._rest_name) + xml_ns = prop_meta.get("ns", model_meta.get("ns", None)) + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + + # attribute + if prop_meta.get("attribute", False) and args[0].get(xml_name) is not None: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].get(xml_name)) + continue + + # unwrapped element is array + if prop_meta.get("unwrapped", False): + # unwrapped array could either use prop items meta/prop meta + if prop_meta.get("itemsName"): + xml_name = prop_meta.get("itemsName") + xml_ns = prop_meta.get("itemNs") + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + items = args[0].findall(xml_name) # pyright: ignore + if len(items) > 0: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, items) + continue + + # text element is primitive type + if prop_meta.get("text", False): + if args[0].text is not None: + dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].text) + continue + + # wrapped element could be normal property or array, it should only have one element + item = args[0].find(xml_name) + if item is not None: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, item) + + # rest thing is additional properties + for e in args[0]: + if e.tag not in existed_attr_keys: + dict_to_pass[e.tag] = _convert_element(e) + else: + dict_to_pass.update( + {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()} + ) else: non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field] if non_attr_kwargs: @@ -510,7 +574,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: def copy(self) -> "Model": return Model(self.__dict__) - def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: disable=unused-argument + def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: if f"{cls.__module__}.{cls.__qualname__}" not in cls._calculated: # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping', # 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object' @@ -521,8 +585,8 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: di annotations = { k: v for mro_class in mros - if hasattr(mro_class, "__annotations__") # pylint: disable=no-member - for k, v in mro_class.__annotations__.items() # pylint: disable=no-member + if hasattr(mro_class, "__annotations__") + for k, v in mro_class.__annotations__.items() } for attr, rf in attr_to_rest_field.items(): rf._module = cls.__module__ @@ -537,31 +601,43 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: di def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: for base in cls.__bases__: - if hasattr(base, "__mapping__"): # pylint: disable=no-member - base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member + if hasattr(base, "__mapping__"): + base.__mapping__[discriminator or cls.__name__] = cls # type: ignore @classmethod - def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]: for v in cls.__dict__.values(): - if ( - isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators - ): # pylint: disable=protected-access - return v._rest_name # pylint: disable=protected-access + if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: + return v return None @classmethod def _deserialize(cls, data, exist_discriminators): - if not hasattr(cls, "__mapping__"): # pylint: disable=no-member + if not hasattr(cls, "__mapping__"): return cls(data) discriminator = cls._get_discriminator(exist_discriminators) - exist_discriminators.append(discriminator) - mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore # pylint: disable=no-member - if mapped_cls == cls: + if discriminator is None: return cls(data) - return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + exist_discriminators.append(discriminator._rest_name) + if isinstance(data, ET.Element): + model_meta = getattr(cls, "_xml", {}) + prop_meta = getattr(discriminator, "_xml", {}) + xml_name = prop_meta.get("name", discriminator._rest_name) + xml_ns = prop_meta.get("ns", model_meta.get("ns", None)) + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + + if data.get(xml_name) is not None: + discriminator_value = data.get(xml_name) + else: + discriminator_value = data.find(xml_name).text # pyright: ignore + else: + discriminator_value = data.get(discriminator._rest_name) + mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore + return mapped_cls._deserialize(data, exist_discriminators) def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: - """Return a dict that can be JSONify using json.dump. + """Return a dict that can be turned into json using json.dump. :keyword bool exclude_readonly: Whether to remove the readonly properties. :returns: A dict JSON compatible object @@ -624,6 +700,8 @@ def _deserialize_dict( ): if obj is None: return obj + if isinstance(obj, ET.Element): + obj = {child.tag: child for child in obj} return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()} @@ -644,6 +722,8 @@ def _deserialize_sequence( ): if obj is None: return obj + if isinstance(obj, ET.Element): + obj = list(obj) return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) @@ -654,12 +734,12 @@ def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.An ) -def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 +def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-branches annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: - if not annotation or annotation in [int, float]: + if not annotation: return None # is it a type alias? @@ -734,7 +814,6 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, try: if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore if len(annotation.__args__) > 1: # pyright: ignore - entry_deserializers = [ _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore @@ -769,12 +848,23 @@ def _deserialize_default( def _deserialize_with_callable( deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any, -): +): # pylint: disable=too-many-return-statements try: if value is None or isinstance(value, _Null): return None + if isinstance(value, ET.Element): + if deserializer is str: + return value.text or "" + if deserializer is int: + return int(value.text) if value.text else None + if deserializer is float: + return float(value.text) if value.text else None + if deserializer is bool: + return value.text == "true" if value.text else None if deserializer is None: return value + if deserializer in [int, float, bool]: + return deserializer(value) if isinstance(deserializer, CaseInsensitiveEnumMeta): try: return deserializer(value) @@ -815,6 +905,7 @@ def __init__( default: typing.Any = _UNSET, format: typing.Optional[str] = None, is_multipart_file_input: bool = False, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, ): self._type = type self._rest_name_input = name @@ -825,6 +916,7 @@ def __init__( self._default = default self._format = format self._is_multipart_file_input = is_multipart_file_input + self._xml = xml if xml is not None else {} @property def _class_type(self) -> typing.Any: @@ -875,6 +967,7 @@ def rest_field( default: typing.Any = _UNSET, format: typing.Optional[str] = None, is_multipart_file_input: bool = False, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> typing.Any: return _RestField( name=name, @@ -883,6 +976,7 @@ def rest_field( default=default, format=format, is_multipart_file_input=is_multipart_file_input, + xml=xml, ) @@ -891,5 +985,175 @@ def rest_discriminator( name: typing.Optional[str] = None, type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin visibility: typing.Optional[typing.List[str]] = None, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Any: + return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility, xml=xml) + + +def serialize_xml(model: Model, exclude_readonly: bool = False) -> str: + """Serialize a model to XML. + + :param Model model: The model to serialize. + :param bool exclude_readonly: Whether to exclude readonly properties. + :returns: The XML representation of the model. + :rtype: str + """ + return ET.tostring(_get_element(model, exclude_readonly), encoding="unicode") # type: ignore + + +def _get_element( + o: typing.Any, + exclude_readonly: bool = False, + parent_meta: typing.Optional[typing.Dict[str, typing.Any]] = None, + wrapped_element: typing.Optional[ET.Element] = None, +) -> typing.Union[ET.Element, typing.List[ET.Element]]: + if _is_model(o): + model_meta = getattr(o, "_xml", {}) + + # if prop is a model, then use the prop element directly, else generate a wrapper of model + if wrapped_element is None: + wrapped_element = _create_xml_element( + model_meta.get("name", o.__class__.__name__), + model_meta.get("prefix"), + model_meta.get("ns"), + ) + + readonly_props = [] + if exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + + for k, v in o.items(): + # do not serialize readonly properties + if exclude_readonly and k in readonly_props: + continue + + prop_rest_field = _get_rest_field(o._attr_to_rest_field, k) + if prop_rest_field: + prop_meta = getattr(prop_rest_field, "_xml").copy() + # use the wire name as xml name if no specific name is set + if prop_meta.get("name") is None: + prop_meta["name"] = k + else: + # additional properties will not have rest field, use the wire name as xml name + prop_meta = {"name": k} + + # if no ns for prop, use model's + if prop_meta.get("ns") is None and model_meta.get("ns"): + prop_meta["ns"] = model_meta.get("ns") + prop_meta["prefix"] = model_meta.get("prefix") + + if prop_meta.get("unwrapped", False): + # unwrapped could only set on array + wrapped_element.extend(_get_element(v, exclude_readonly, prop_meta)) + elif prop_meta.get("text", False): + # text could only set on primitive type + wrapped_element.text = _get_primitive_type_value(v) + elif prop_meta.get("attribute", False): + xml_name = prop_meta.get("name", k) + if prop_meta.get("ns"): + ET.register_namespace(prop_meta.get("prefix"), prop_meta.get("ns")) # pyright: ignore + xml_name = "{" + prop_meta.get("ns") + "}" + xml_name # pyright: ignore + # attribute should be primitive type + wrapped_element.set(xml_name, _get_primitive_type_value(v)) + else: + # other wrapped prop element + wrapped_element.append(_get_wrapped_element(v, exclude_readonly, prop_meta)) + return wrapped_element + if isinstance(o, list): + return [_get_element(x, exclude_readonly, parent_meta) for x in o] # type: ignore + if isinstance(o, dict): + result = [] + for k, v in o.items(): + result.append( + _get_wrapped_element( + v, + exclude_readonly, + { + "name": k, + "ns": parent_meta.get("ns") if parent_meta else None, + "prefix": parent_meta.get("prefix") if parent_meta else None, + }, + ) + ) + return result + + # primitive case need to create element based on parent_meta + if parent_meta: + return _get_wrapped_element( + o, + exclude_readonly, + { + "name": parent_meta.get("itemsName", parent_meta.get("name")), + "prefix": parent_meta.get("itemsPrefix", parent_meta.get("prefix")), + "ns": parent_meta.get("itemsNs", parent_meta.get("ns")), + }, + ) + + raise ValueError("Could not serialize value into xml: " + o) + + +def _get_wrapped_element( + v: typing.Any, + exclude_readonly: bool, + meta: typing.Optional[typing.Dict[str, typing.Any]], +) -> ET.Element: + wrapped_element = _create_xml_element( + meta.get("name") if meta else None, meta.get("prefix") if meta else None, meta.get("ns") if meta else None + ) + if isinstance(v, (dict, list)): + wrapped_element.extend(_get_element(v, exclude_readonly, meta)) + elif _is_model(v): + _get_element(v, exclude_readonly, meta, wrapped_element) + else: + wrapped_element.text = _get_primitive_type_value(v) + return wrapped_element + + +def _get_primitive_type_value(v) -> str: + if v is True: + return "true" + if v is False: + return "false" + if isinstance(v, _Null): + return "" + return str(v) + + +def _create_xml_element(tag, prefix=None, ns=None): + if prefix and ns: + ET.register_namespace(prefix, ns) + if ns: + return ET.Element("{" + ns + "}" + tag) + return ET.Element(tag) + + +def _deserialize_xml( + deserializer: typing.Any, + value: str, ) -> typing.Any: - return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility) + element = ET.fromstring(value) # nosec + return _deserialize(deserializer, element) + + +def _convert_element(e: ET.Element): + # dict case + if len(e.attrib) > 0 or len({child.tag for child in e}) > 1: + dict_result: typing.Dict[str, typing.Any] = {} + for child in e: + if dict_result.get(child.tag) is not None: + if isinstance(dict_result[child.tag], list): + dict_result[child.tag].append(_convert_element(child)) + else: + dict_result[child.tag] = [dict_result[child.tag], _convert_element(child)] + else: + dict_result[child.tag] = _convert_element(child) + dict_result.update(e.attrib) + return dict_result + # array case + if len(e) > 0: + array_result: typing.List[typing.Any] = [] + for child in e: + array_result.append(_convert_element(child)) + return array_result + # primitive case + return e.text diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py index d3ebd561f739..ab87088736aa 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py @@ -5,13 +5,19 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._operations import ChatCompletionsClientOperationsMixin -from ._operations import EmbeddingsClientOperationsMixin -from ._operations import ImageEmbeddingsClientOperationsMixin +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import ChatCompletionsClientOperationsMixin # type: ignore +from ._operations import EmbeddingsClientOperationsMixin # type: ignore +from ._operations import ImageEmbeddingsClientOperationsMixin # type: ignore from ._patch import __all__ as _patch_all -from ._patch import * # pylint: disable=unused-wildcard-import +from ._patch import * from ._patch import patch_sdk as _patch_sdk __all__ = [ @@ -19,5 +25,5 @@ "EmbeddingsClientOperationsMixin", "ImageEmbeddingsClientOperationsMixin", ] -__all__.extend([p for p in _patch_all if p not in __all__]) +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py index 3a24ee5736d3..525761e8b5ce 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py @@ -1,4 +1,3 @@ -# pylint: disable=too-many-lines,too-many-statements # coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. @@ -9,7 +8,7 @@ from io import IOBase import json import sys -from typing import Any, Callable, Dict, IO, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload from azure.core.exceptions import ( ClientAuthenticationError, @@ -34,7 +33,7 @@ if sys.version_info >= (3, 9): from collections.abc import MutableMapping else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports + from typing import MutableMapping # type: ignore JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object _Unset: Any = object() T = TypeVar("T") @@ -194,23 +193,9 @@ def _complete( def _complete( self, *, - messages: List[_models.ChatRequestMessage], + chat_completions_options: _models._models.ChatCompletionsOptions, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, content_type: str = "application/json", - frequency_penalty: Optional[float] = None, - stream_parameter: Optional[bool] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, **kwargs: Any ) -> _models.ChatCompletions: ... @overload @@ -228,25 +213,10 @@ def _complete( self, body: Union[JSON, IO[bytes]] = _Unset, *, - messages: List[_models.ChatRequestMessage] = _Unset, + chat_completions_options: _models._models.ChatCompletionsOptions = _Unset, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, - frequency_penalty: Optional[float] = None, - stream_parameter: Optional[bool] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, **kwargs: Any ) -> _models.ChatCompletions: - # pylint: disable=too-many-locals """Gets chat completions for the provided chat messages. Completions support a wide variety of tasks and generate text that continues from or "completes" @@ -255,87 +225,16 @@ def _complete( :param body: Is either a JSON type or a IO[bytes] type. Required. :type body: JSON or IO[bytes] - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, - are passed in the JSON request payload. - This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and - "pass-through". Default value is None. + :keyword chat_completions_options: Required. + :paramtype chat_completions_options: ~azure.ai.inference.models._models.ChatCompletionsOptions + :keyword extra_params: Known values are: "error", "drop", and "pass-through". Default value is + None. :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative - frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. Default value is None. - :paramtype frequency_penalty: float - :keyword stream_parameter: A value indicating whether chat completions should be streamed for - this request. Default value is None. - :paramtype stream_parameter: bool - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the - model's likelihood to output new topics. - Supported range is [-2, 2]. Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON - via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: A list of tools the model may request to call. Currently, only functions are - supported as a tool. The model - may response with a function call request and provide the input arguments in JSON format for - that function. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. Default - value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str :return: ChatCompletions. The ChatCompletions is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ChatCompletions :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -350,23 +249,9 @@ def _complete( cls: ClsType[_models.ChatCompletions] = kwargs.pop("cls", None) if body is _Unset: - if messages is _Unset: - raise TypeError("missing required argument: messages") - body = { - "frequency_penalty": frequency_penalty, - "max_tokens": max_tokens, - "messages": messages, - "model": model, - "presence_penalty": presence_penalty, - "response_format": response_format, - "seed": seed, - "stop": stop, - "stream": stream_parameter, - "temperature": temperature, - "tool_choice": tool_choice, - "tools": tools, - "top_p": top_p, - } + if chat_completions_options is _Unset: + raise TypeError("missing required argument: chat_completions_options") + body = {"chatCompletionsOptions": chat_completions_options} body = {k: v for k, v in body.items() if v is not None} content_type = content_type or "application/json" _content = None @@ -418,14 +303,12 @@ def _complete( def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """Returns information about the AI model. The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. :return: ModelInfo. The ModelInfo is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -556,7 +439,7 @@ def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -631,14 +514,12 @@ def _embed( def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """Returns information about the AI model. The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. :return: ModelInfo. The ModelInfo is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -772,7 +653,7 @@ def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -847,14 +728,12 @@ def _embed( def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """Returns information about the AI model. The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. :return: ModelInfo. The ModelInfo is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py index 050a8d1ab96c..f7dd32510333 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py @@ -2,1285 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -# pylint: disable=too-many-lines) """Customize generated code here. Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize - -Why do we patch auto-generated code? -1. Add support for input argument `model_extras` (all clients) -2. Add support for function load_client -3. Add support for setting sticky chat completions/embeddings input arguments in the client constructor -4. Add support for get_model_info, while caching the result (all clients) -5. Add support for chat completion streaming (ChatCompletionsClient client only) -6. Add support for friendly print of result objects (__str__ method) (all clients) -7. Add support for load() method in ImageUrl class (see /models/_patch.py) -8. Add support for sending two auth headers for api-key auth (all clients) - """ -import json -import logging -import sys - -from io import IOBase -from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable - -from azure.core.pipeline import PipelineResponse -from azure.core.credentials import AzureKeyCredential -from azure.core.tracing.decorator import distributed_trace -from azure.core.utils import case_insensitive_dict -from azure.core.exceptions import ( - ClientAuthenticationError, - HttpResponseError, - map_error, - ResourceExistsError, - ResourceNotFoundError, - ResourceNotModifiedError, -) -from . import models as _models -from ._model_base import SdkJSONEncoder, _deserialize -from ._serialization import Serializer -from ._operations._operations import ( - build_chat_completions_complete_request, - build_embeddings_embed_request, - build_image_embeddings_embed_request, -) -from ._client import ChatCompletionsClient as ChatCompletionsClientGenerated -from ._client import EmbeddingsClient as EmbeddingsClientGenerated -from ._client import ImageEmbeddingsClient as ImageEmbeddingsClientGenerated - -if sys.version_info >= (3, 9): - from collections.abc import MutableMapping -else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports - -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from azure.core.credentials import TokenCredential - -JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object -_Unset: Any = object() - -_SERIALIZER = Serializer() -_SERIALIZER.client_side_validation = False - -_LOGGER = logging.getLogger(__name__) - - -def load_client( - endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any -) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: - """ - Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route - on the given endpoint, to determine the model type and therefore which client to instantiate. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - :return: The appropriate synchronous client associated with the given endpoint - :rtype: ~azure.ai.inference.ChatCompletionsClient or ~azure.ai.inference.EmbeddingsClient - or ~azure.ai.inference.ImageEmbeddingsClient - :raises ~azure.core.exceptions.HttpResponseError: - """ - - with ChatCompletionsClient( - endpoint, credential, **kwargs - ) as client: # Pick any of the clients, it does not matter. - model_info = client.get_model_info() # type: ignore - - _LOGGER.info("model_info=%s", model_info) - if not model_info.model_type: - raise ValueError( - "The AI model information is missing a value for `model type`. Cannot create an appropriate client." - ) - - # TODO: Remove "completions", "chat-comletions" and "embedding" once Mistral Large and Cohere fixes their model type - if model_info.model_type in (_models.ModelType.CHAT, "completion", "chat-completion", "chat-completions"): - chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs) - chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return chat_completion_client - - if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"): - embedding_client = EmbeddingsClient(endpoint, credential, **kwargs) - embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init - return embedding_client - - if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS: - image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs) - image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return image_embedding_client - - raise ValueError(f"No client available to support AI model type `{model_info.model_type}`") - - -class ChatCompletionsClient(ChatCompletionsClientGenerated): # pylint: disable=too-many-instance-attributes - """ChatCompletionsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "TokenCredential"], - *, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default chat completions settings, to be applied in all future service calls - # unless overridden by arguments in the `complete` method. - self._frequency_penalty = frequency_penalty - self._presence_penalty = presence_penalty - self._temperature = temperature - self._top_p = top_p - self._max_tokens = max_tokens - self._response_format = response_format - self._stop = stop - self._tools = tools - self._tool_choice = tool_choice - self._seed = seed - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - - @overload - def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Literal[False] = False, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.ChatCompletions: ... - - @overload - def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Literal[True], - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Iterable[_models.StreamingChatCompletionsUpdate]: ... - - @overload - def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. The method makes a REST API call to the `/chat/completions` route - on the given endpoint. - When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting StreamingChatCompletions - object to get content updates as they arrive. By default, the response is a ChatCompletions object - (non-streaming). - - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def complete( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def complete( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - # pylint: disable=too-many-locals - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - # pylint:disable=client-method-missing-tracing-decorator - def complete( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - messages: List[_models.ChatRequestMessage] = _Unset, - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - # pylint: disable=too-many-locals - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting :class:`~azure.ai.inference.models.StreamingChatCompletions` - object to get content updates as they arrive. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if messages is _Unset: - raise TypeError("missing required argument: messages") - body = { - "messages": messages, - "stream": stream, - "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, - "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, - "model": model if model is not None else self._model, - "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, - "response_format": response_format if response_format is not None else self._response_format, - "seed": seed if seed is not None else self._seed, - "stop": stop if stop is not None else self._stop, - "temperature": temperature if temperature is not None else self._temperature, - "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, - "tools": tools if tools is not None else self._tools, - "top_p": top_p if top_p is not None else self._top_p, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): - stream = body["stream"] - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_chat_completions_complete_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = stream or False - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - return _models.StreamingChatCompletions(response) - - return _deserialize(_models._patch.ChatCompletions, response.json()) # pylint: disable=protected-access - - @distributed_trace - def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class EmbeddingsClient(EmbeddingsClientGenerated): - """EmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "TokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - def embed( - self, - *, - input: List[str], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace - def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[str] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace - def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): - """ImageEmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "TokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - def embed( - self, - *, - input: List[_models.EmbeddingInput], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace - def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[_models.EmbeddingInput] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_image_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace - def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - +from typing import List -__all__: List[str] = [ - "load_client", - "ChatCompletionsClient", - "EmbeddingsClient", - "ImageEmbeddingsClient", -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py index 8139854b97bb..ce17d1798ce7 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines # -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. @@ -24,7 +25,6 @@ # # -------------------------------------------------------------------------- -# pylint: skip-file # pyright: reportUnnecessaryTypeIgnoreComment=false from base64 import b64decode, b64encode @@ -52,7 +52,6 @@ MutableMapping, Type, List, - Mapping, ) try: @@ -91,6 +90,8 @@ def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: :param data: Input, could be bytes or stream (will be decoded with UTF8) or text :type data: str or bytes or IO :param str content_type: The content type. + :return: The deserialized data. + :rtype: object """ if hasattr(data, "read"): # Assume a stream @@ -112,7 +113,7 @@ def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: try: return json.loads(data_as_str) except ValueError as err: - raise DeserializationError("JSON is invalid: {}".format(err), err) + raise DeserializationError("JSON is invalid: {}".format(err), err) from err elif "xml" in (content_type or []): try: @@ -155,6 +156,11 @@ def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], Use bytes and headers to NOT use any requests/aiohttp or whatever specific implementation. Headers will tested for "content-type" + + :param bytes body_bytes: The body of the response. + :param dict headers: The headers of the response. + :returns: The deserialized data. + :rtype: object """ # Try to use content-type from headers if available content_type = None @@ -184,15 +190,30 @@ class UTC(datetime.tzinfo): """Time Zone info for handling UTC""" def utcoffset(self, dt): - """UTF offset for UTC is 0.""" + """UTF offset for UTC is 0. + + :param datetime.datetime dt: The datetime + :returns: The offset + :rtype: datetime.timedelta + """ return datetime.timedelta(0) def tzname(self, dt): - """Timestamp representation.""" + """Timestamp representation. + + :param datetime.datetime dt: The datetime + :returns: The timestamp representation + :rtype: str + """ return "Z" def dst(self, dt): - """No daylight saving for UTC.""" + """No daylight saving for UTC. + + :param datetime.datetime dt: The datetime + :returns: The daylight saving time + :rtype: datetime.timedelta + """ return datetime.timedelta(hours=1) @@ -206,7 +227,7 @@ class _FixedOffset(datetime.tzinfo): # type: ignore :param datetime.timedelta offset: offset in timedelta format """ - def __init__(self, offset): + def __init__(self, offset) -> None: self.__offset = offset def utcoffset(self, dt): @@ -235,24 +256,26 @@ def __getinitargs__(self): _FLATTEN = re.compile(r"(? None: self.additional_properties: Optional[Dict[str, Any]] = {} - for k in kwargs: + for k in kwargs: # pylint: disable=consider-using-dict-items if k not in self._attribute_map: _LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__) elif k in self._validation and self._validation[k].get("readonly", False): @@ -300,13 +330,23 @@ def __init__(self, **kwargs: Any) -> None: setattr(self, k, kwargs[k]) def __eq__(self, other: Any) -> bool: - """Compare objects by comparing all attributes.""" + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are equal + :rtype: bool + """ if isinstance(other, self.__class__): return self.__dict__ == other.__dict__ return False def __ne__(self, other: Any) -> bool: - """Compare objects by comparing all attributes.""" + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are not equal + :rtype: bool + """ return not self.__eq__(other) def __str__(self) -> str: @@ -326,7 +366,11 @@ def is_xml_model(cls) -> bool: @classmethod def _create_xml_node(cls): - """Create XML node.""" + """Create XML node. + + :returns: The XML node + :rtype: xml.etree.ElementTree.Element + """ try: xml_map = cls._xml_map # type: ignore except AttributeError: @@ -346,7 +390,9 @@ def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON: :rtype: dict """ serializer = Serializer(self._infer_class_models()) - return serializer._serialize(self, keep_readonly=keep_readonly, **kwargs) # type: ignore + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, keep_readonly=keep_readonly, **kwargs + ) def as_dict( self, @@ -380,12 +426,15 @@ def my_key_transformer(key, attr_desc, value): If you want XML serialization, you can pass the kwargs is_xml=True. + :param bool keep_readonly: If you want to serialize the readonly attributes :param function key_transformer: A key transformer function. :returns: A dict JSON compatible object :rtype: dict """ serializer = Serializer(self._infer_class_models()) - return serializer._serialize(self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs) # type: ignore + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs + ) @classmethod def _infer_class_models(cls): @@ -395,7 +444,7 @@ def _infer_class_models(cls): client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} if cls.__name__ not in client_models: raise ValueError("Not Autorest generated code") - except Exception: + except Exception: # pylint: disable=broad-exception-caught # Assume it's not Autorest generated (tests?). Add ourselves as dependencies. client_models = {cls.__name__: cls} return client_models @@ -408,6 +457,7 @@ def deserialize(cls: Type[ModelType], data: Any, content_type: Optional[str] = N :param str content_type: JSON by default, set application/xml if XML. :returns: An instance of this model :raises: DeserializationError if something went wrong + :rtype: ModelType """ deserializer = Deserializer(cls._infer_class_models()) return deserializer(cls.__name__, data, content_type=content_type) # type: ignore @@ -426,9 +476,11 @@ def from_dict( and last_rest_key_case_insensitive_extractor) :param dict data: A dict using RestAPI structure + :param function key_extractors: A key extractor function. :param str content_type: JSON by default, set application/xml if XML. :returns: An instance of this model :raises: DeserializationError if something went wrong + :rtype: ModelType """ deserializer = Deserializer(cls._infer_class_models()) deserializer.key_extractors = ( # type: ignore @@ -448,21 +500,25 @@ def _flatten_subtype(cls, key, objects): return {} result = dict(cls._subtype_map[key]) for valuetype in cls._subtype_map[key].values(): - result.update(objects[valuetype]._flatten_subtype(key, objects)) + result.update(objects[valuetype]._flatten_subtype(key, objects)) # pylint: disable=protected-access return result @classmethod def _classify(cls, response, objects): """Check the class _subtype_map for any child classes. We want to ignore any inherited _subtype_maps. - Remove the polymorphic key from the initial data. + + :param dict response: The initial data + :param dict objects: The class objects + :returns: The class to be used + :rtype: class """ for subtype_key in cls.__dict__.get("_subtype_map", {}).keys(): subtype_value = None if not isinstance(response, ET.Element): rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1] - subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None) + subtype_value = response.get(rest_api_response_key, None) or response.get(subtype_key, None) else: subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response) if subtype_value: @@ -501,11 +557,13 @@ def _decode_attribute_map_key(key): inside the received data. :param str key: A key string from the generated code + :returns: The decoded key + :rtype: str """ return key.replace("\\.", ".") -class Serializer(object): +class Serializer(object): # pylint: disable=too-many-public-methods """Request object model serializer.""" basic_types = {str: "str", int: "int", bool: "bool", float: "float"} @@ -540,7 +598,7 @@ class Serializer(object): "multiple": lambda x, y: x % y != 0, } - def __init__(self, classes: Optional[Mapping[str, type]] = None): + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: self.serialize_type = { "iso-8601": Serializer.serialize_iso, "rfc-1123": Serializer.serialize_rfc, @@ -560,13 +618,16 @@ def __init__(self, classes: Optional[Mapping[str, type]] = None): self.key_transformer = full_restapi_key_transformer self.client_side_validation = True - def _serialize(self, target_obj, data_type=None, **kwargs): + def _serialize( # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements, too-many-locals + self, target_obj, data_type=None, **kwargs + ): """Serialize data into a string according to type. - :param target_obj: The data to be serialized. + :param object target_obj: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: str, dict :raises: SerializationError if serialization fails. + :returns: The serialized data. """ key_transformer = kwargs.get("key_transformer", self.key_transformer) keep_readonly = kwargs.get("keep_readonly", False) @@ -592,12 +653,14 @@ def _serialize(self, target_obj, data_type=None, **kwargs): serialized = {} if is_xml_model_serialization: - serialized = target_obj._create_xml_node() + serialized = target_obj._create_xml_node() # pylint: disable=protected-access try: - attributes = target_obj._attribute_map + attributes = target_obj._attribute_map # pylint: disable=protected-access for attr, attr_desc in attributes.items(): attr_name = attr - if not keep_readonly and target_obj._validation.get(attr_name, {}).get("readonly", False): + if not keep_readonly and target_obj._validation.get( # pylint: disable=protected-access + attr_name, {} + ).get("readonly", False): continue if attr_name == "additional_properties" and attr_desc["key"] == "": @@ -633,7 +696,8 @@ def _serialize(self, target_obj, data_type=None, **kwargs): if isinstance(new_attr, list): serialized.extend(new_attr) # type: ignore elif isinstance(new_attr, ET.Element): - # If the down XML has no XML/Name, we MUST replace the tag with the local tag. But keeping the namespaces. + # If the down XML has no XML/Name, + # we MUST replace the tag with the local tag. But keeping the namespaces. if "name" not in getattr(orig_attr, "_xml_map", {}): splitted_tag = new_attr.tag.split("}") if len(splitted_tag) == 2: # Namespace @@ -664,17 +728,17 @@ def _serialize(self, target_obj, data_type=None, **kwargs): except (AttributeError, KeyError, TypeError) as err: msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) raise SerializationError(msg) from err - else: - return serialized + return serialized def body(self, data, data_type, **kwargs): """Serialize data intended for a request body. - :param data: The data to be serialized. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: dict :raises: SerializationError if serialization fails. :raises: ValueError if data is None + :returns: The serialized request body """ # Just in case this is a dict @@ -703,7 +767,7 @@ def body(self, data, data_type, **kwargs): attribute_key_case_insensitive_extractor, last_rest_key_case_insensitive_extractor, ] - data = deserializer._deserialize(data_type, data) + data = deserializer._deserialize(data_type, data) # pylint: disable=protected-access except DeserializationError as err: raise SerializationError("Unable to build a model: " + str(err)) from err @@ -712,9 +776,11 @@ def body(self, data, data_type, **kwargs): def url(self, name, data, data_type, **kwargs): """Serialize data intended for a URL path. - :param data: The data to be serialized. + :param str name: The name of the URL path parameter. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: str + :returns: The serialized URL path :raises: TypeError if serialization fails. :raises: ValueError if data is None """ @@ -728,21 +794,20 @@ def url(self, name, data, data_type, **kwargs): output = output.replace("{", quote("{")).replace("}", quote("}")) else: output = quote(str(output), safe="") - except SerializationError: - raise TypeError("{} must be type {}.".format(name, data_type)) - else: - return output + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return output def query(self, name, data, data_type, **kwargs): """Serialize data intended for a URL query. - :param data: The data to be serialized. + :param str name: The name of the query parameter. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. - :keyword bool skip_quote: Whether to skip quote the serialized result. - Defaults to False. :rtype: str, list :raises: TypeError if serialization fails. :raises: ValueError if data is None + :returns: The serialized query parameter """ try: # Treat the list aside, since we don't want to encode the div separator @@ -759,19 +824,20 @@ def query(self, name, data, data_type, **kwargs): output = str(output) else: output = quote(str(output), safe="") - except SerializationError: - raise TypeError("{} must be type {}.".format(name, data_type)) - else: - return str(output) + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) def header(self, name, data, data_type, **kwargs): """Serialize data intended for a request header. - :param data: The data to be serialized. + :param str name: The name of the header. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: str :raises: TypeError if serialization fails. :raises: ValueError if data is None + :returns: The serialized header """ try: if data_type in ["[str]"]: @@ -780,21 +846,20 @@ def header(self, name, data, data_type, **kwargs): output = self.serialize_data(data, data_type, **kwargs) if data_type == "bool": output = json.dumps(output) - except SerializationError: - raise TypeError("{} must be type {}.".format(name, data_type)) - else: - return str(output) + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) def serialize_data(self, data, data_type, **kwargs): """Serialize generic data according to supplied data type. - :param data: The data to be serialized. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. - :param bool required: Whether it's essential that the data not be - empty or None :raises: AttributeError if required data is None. :raises: ValueError if data is None :raises: SerializationError if serialization fails. + :returns: The serialized data. + :rtype: str, int, float, bool, dict, list """ if data is None: raise ValueError("No value for given attribute") @@ -805,7 +870,7 @@ def serialize_data(self, data, data_type, **kwargs): if data_type in self.basic_types.values(): return self.serialize_basic(data, data_type, **kwargs) - elif data_type in self.serialize_type: + if data_type in self.serialize_type: return self.serialize_type[data_type](data, **kwargs) # If dependencies is empty, try with current data class @@ -821,11 +886,10 @@ def serialize_data(self, data, data_type, **kwargs): except (ValueError, TypeError) as err: msg = "Unable to serialize value: {!r} as type: {!r}." raise SerializationError(msg.format(data, data_type)) from err - else: - return self._serialize(data, **kwargs) + return self._serialize(data, **kwargs) @classmethod - def _get_custom_serializers(cls, data_type, **kwargs): + def _get_custom_serializers(cls, data_type, **kwargs): # pylint: disable=inconsistent-return-statements custom_serializer = kwargs.get("basic_types_serializers", {}).get(data_type) if custom_serializer: return custom_serializer @@ -841,23 +905,26 @@ def serialize_basic(cls, data, data_type, **kwargs): - basic_types_serializers dict[str, callable] : If set, use the callable as serializer - is_xml bool : If set, use xml_basic_types_serializers - :param data: Object to be serialized. + :param obj data: Object to be serialized. :param str data_type: Type of object in the iterable. + :rtype: str, int, float, bool + :return: serialized object """ custom_serializer = cls._get_custom_serializers(data_type, **kwargs) if custom_serializer: return custom_serializer(data) if data_type == "str": return cls.serialize_unicode(data) - return eval(data_type)(data) # nosec + return eval(data_type)(data) # nosec # pylint: disable=eval-used @classmethod def serialize_unicode(cls, data): """Special handling for serializing unicode strings in Py2. Encode to UTF-8 if unicode, otherwise handle as a str. - :param data: Object to be serialized. + :param str data: Object to be serialized. :rtype: str + :return: serialized object """ try: # If I received an enum, return its value return data.value @@ -871,8 +938,7 @@ def serialize_unicode(cls, data): return data except NameError: return str(data) - else: - return str(data) + return str(data) def serialize_iter(self, data, iter_type, div=None, **kwargs): """Serialize iterable. @@ -882,15 +948,13 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs): serialization_ctxt['type'] should be same as data_type. - is_xml bool : If set, serialize as XML - :param list attr: Object to be serialized. + :param list data: Object to be serialized. :param str iter_type: Type of object in the iterable. - :param bool required: Whether the objects in the iterable must - not be None or empty. :param str div: If set, this str will be used to combine the elements in the iterable into a combined string. Default is 'None'. - :keyword bool do_quote: Whether to quote the serialized result of each iterable element. Defaults to False. :rtype: list, str + :return: serialized iterable """ if isinstance(data, str): raise SerializationError("Refuse str type as a valid iter type.") @@ -945,9 +1009,8 @@ def serialize_dict(self, attr, dict_type, **kwargs): :param dict attr: Object to be serialized. :param str dict_type: Type of object in the dictionary. - :param bool required: Whether the objects in the dictionary must - not be None or empty. :rtype: dict + :return: serialized dictionary """ serialization_ctxt = kwargs.get("serialization_ctxt", {}) serialized = {} @@ -971,7 +1034,7 @@ def serialize_dict(self, attr, dict_type, **kwargs): return serialized - def serialize_object(self, attr, **kwargs): + def serialize_object(self, attr, **kwargs): # pylint: disable=too-many-return-statements """Serialize a generic object. This will be handled as a dictionary. If object passed in is not a basic type (str, int, float, dict, list) it will simply be @@ -979,6 +1042,7 @@ def serialize_object(self, attr, **kwargs): :param dict attr: Object to be serialized. :rtype: dict or str + :return: serialized object """ if attr is None: return None @@ -1003,7 +1067,7 @@ def serialize_object(self, attr, **kwargs): return self.serialize_decimal(attr) # If it's a model or I know this dependency, serialize as a Model - elif obj_type in self.dependencies.values() or isinstance(attr, Model): + if obj_type in self.dependencies.values() or isinstance(attr, Model): return self._serialize(attr) if obj_type == dict: @@ -1034,56 +1098,61 @@ def serialize_enum(attr, enum_obj=None): try: enum_obj(result) # type: ignore return result - except ValueError: + except ValueError as exc: for enum_value in enum_obj: # type: ignore if enum_value.value.lower() == str(attr).lower(): return enum_value.value error = "{!r} is not valid value for enum {!r}" - raise SerializationError(error.format(attr, enum_obj)) + raise SerializationError(error.format(attr, enum_obj)) from exc @staticmethod - def serialize_bytearray(attr, **kwargs): + def serialize_bytearray(attr, **kwargs): # pylint: disable=unused-argument """Serialize bytearray into base-64 string. - :param attr: Object to be serialized. + :param str attr: Object to be serialized. :rtype: str + :return: serialized base64 """ return b64encode(attr).decode() @staticmethod - def serialize_base64(attr, **kwargs): + def serialize_base64(attr, **kwargs): # pylint: disable=unused-argument """Serialize str into base-64 string. - :param attr: Object to be serialized. + :param str attr: Object to be serialized. :rtype: str + :return: serialized base64 """ encoded = b64encode(attr).decode("ascii") return encoded.strip("=").replace("+", "-").replace("/", "_") @staticmethod - def serialize_decimal(attr, **kwargs): + def serialize_decimal(attr, **kwargs): # pylint: disable=unused-argument """Serialize Decimal object to float. - :param attr: Object to be serialized. + :param decimal attr: Object to be serialized. :rtype: float + :return: serialized decimal """ return float(attr) @staticmethod - def serialize_long(attr, **kwargs): + def serialize_long(attr, **kwargs): # pylint: disable=unused-argument """Serialize long (Py2) or int (Py3). - :param attr: Object to be serialized. + :param int attr: Object to be serialized. :rtype: int/long + :return: serialized long """ return _long_type(attr) @staticmethod - def serialize_date(attr, **kwargs): + def serialize_date(attr, **kwargs): # pylint: disable=unused-argument """Serialize Date object into ISO-8601 formatted string. :param Date attr: Object to be serialized. :rtype: str + :return: serialized date """ if isinstance(attr, str): attr = isodate.parse_date(attr) @@ -1091,11 +1160,12 @@ def serialize_date(attr, **kwargs): return t @staticmethod - def serialize_time(attr, **kwargs): + def serialize_time(attr, **kwargs): # pylint: disable=unused-argument """Serialize Time object into ISO-8601 formatted string. :param datetime.time attr: Object to be serialized. :rtype: str + :return: serialized time """ if isinstance(attr, str): attr = isodate.parse_time(attr) @@ -1105,30 +1175,32 @@ def serialize_time(attr, **kwargs): return t @staticmethod - def serialize_duration(attr, **kwargs): + def serialize_duration(attr, **kwargs): # pylint: disable=unused-argument """Serialize TimeDelta object into ISO-8601 formatted string. :param TimeDelta attr: Object to be serialized. :rtype: str + :return: serialized duration """ if isinstance(attr, str): attr = isodate.parse_duration(attr) return isodate.duration_isoformat(attr) @staticmethod - def serialize_rfc(attr, **kwargs): + def serialize_rfc(attr, **kwargs): # pylint: disable=unused-argument """Serialize Datetime object into RFC-1123 formatted string. :param Datetime attr: Object to be serialized. :rtype: str :raises: TypeError if format invalid. + :return: serialized rfc """ try: if not attr.tzinfo: _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") utc = attr.utctimetuple() - except AttributeError: - raise TypeError("RFC1123 object must be valid Datetime object.") + except AttributeError as exc: + raise TypeError("RFC1123 object must be valid Datetime object.") from exc return "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT".format( Serializer.days[utc.tm_wday], @@ -1141,12 +1213,13 @@ def serialize_rfc(attr, **kwargs): ) @staticmethod - def serialize_iso(attr, **kwargs): + def serialize_iso(attr, **kwargs): # pylint: disable=unused-argument """Serialize Datetime object into ISO-8601 formatted string. :param Datetime attr: Object to be serialized. :rtype: str :raises: SerializationError if format invalid. + :return: serialized iso """ if isinstance(attr, str): attr = isodate.parse_datetime(attr) @@ -1172,13 +1245,14 @@ def serialize_iso(attr, **kwargs): raise TypeError(msg) from err @staticmethod - def serialize_unix(attr, **kwargs): + def serialize_unix(attr, **kwargs): # pylint: disable=unused-argument """Serialize Datetime object into IntTime format. This is represented as seconds. :param Datetime attr: Object to be serialized. :rtype: int :raises: SerializationError if format invalid + :return: serialied unix """ if isinstance(attr, int): return attr @@ -1186,11 +1260,11 @@ def serialize_unix(attr, **kwargs): if not attr.tzinfo: _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") return int(calendar.timegm(attr.utctimetuple())) - except AttributeError: - raise TypeError("Unix time object must be valid Datetime object.") + except AttributeError as exc: + raise TypeError("Unix time object must be valid Datetime object.") from exc -def rest_key_extractor(attr, attr_desc, data): +def rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument key = attr_desc["key"] working_data = data @@ -1211,7 +1285,9 @@ def rest_key_extractor(attr, attr_desc, data): return working_data.get(key) -def rest_key_case_insensitive_extractor(attr, attr_desc, data): +def rest_key_case_insensitive_extractor( # pylint: disable=unused-argument, inconsistent-return-statements + attr, attr_desc, data +): key = attr_desc["key"] working_data = data @@ -1232,17 +1308,29 @@ def rest_key_case_insensitive_extractor(attr, attr_desc, data): return attribute_key_case_insensitive_extractor(key, None, working_data) -def last_rest_key_extractor(attr, attr_desc, data): - """Extract the attribute in "data" based on the last part of the JSON path key.""" +def last_rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument + """Extract the attribute in "data" based on the last part of the JSON path key. + + :param str attr: The attribute to extract + :param dict attr_desc: The attribute description + :param dict data: The data to extract from + :rtype: object + :returns: The extracted attribute + """ key = attr_desc["key"] dict_keys = _FLATTEN.split(key) return attribute_key_extractor(dict_keys[-1], None, data) -def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): +def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): # pylint: disable=unused-argument """Extract the attribute in "data" based on the last part of the JSON path key. This is the case insensitive version of "last_rest_key_extractor" + :param str attr: The attribute to extract + :param dict attr_desc: The attribute description + :param dict data: The data to extract from + :rtype: object + :returns: The extracted attribute """ key = attr_desc["key"] dict_keys = _FLATTEN.split(key) @@ -1279,7 +1367,7 @@ def _extract_name_from_internal_type(internal_type): return xml_name -def xml_key_extractor(attr, attr_desc, data): +def xml_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument,too-many-return-statements if isinstance(data, dict): return None @@ -1331,22 +1419,21 @@ def xml_key_extractor(attr, attr_desc, data): if is_iter_type: if is_wrapped: return None # is_wrapped no node, we want None - else: - return [] # not wrapped, assume empty list + return [] # not wrapped, assume empty list return None # Assume it's not there, maybe an optional node. # If is_iter_type and not wrapped, return all found children if is_iter_type: if not is_wrapped: return children - else: # Iter and wrapped, should have found one node only (the wrap one) - if len(children) != 1: - raise DeserializationError( - "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( - xml_name - ) + # Iter and wrapped, should have found one node only (the wrap one) + if len(children) != 1: + raise DeserializationError( + "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( # pylint: disable=line-too-long + xml_name ) - return list(children[0]) # Might be empty list and that's ok. + ) + return list(children[0]) # Might be empty list and that's ok. # Here it's not a itertype, we should have found one element only or empty if len(children) > 1: @@ -1363,9 +1450,9 @@ class Deserializer(object): basic_types = {str: "str", int: "int", bool: "bool", float: "float"} - valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") + valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") - def __init__(self, classes: Optional[Mapping[str, type]] = None): + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: self.deserialize_type = { "iso-8601": Deserializer.deserialize_iso, "rfc-1123": Deserializer.deserialize_rfc, @@ -1403,11 +1490,12 @@ def __call__(self, target_obj, response_data, content_type=None): :param str content_type: Swagger "produces" if available. :raises: DeserializationError if deserialization fails. :return: Deserialized object. + :rtype: object """ data = self._unpack_content(response_data, content_type) return self._deserialize(target_obj, data) - def _deserialize(self, target_obj, data): + def _deserialize(self, target_obj, data): # pylint: disable=inconsistent-return-statements """Call the deserializer on a model. Data needs to be already deserialized as JSON or XML ElementTree @@ -1416,12 +1504,13 @@ def _deserialize(self, target_obj, data): :param object data: Object to deserialize. :raises: DeserializationError if deserialization fails. :return: Deserialized object. + :rtype: object """ # This is already a model, go recursive just in case if hasattr(data, "_attribute_map"): constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")] try: - for attr, mapconfig in data._attribute_map.items(): + for attr, mapconfig in data._attribute_map.items(): # pylint: disable=protected-access if attr in constants: continue value = getattr(data, attr) @@ -1440,13 +1529,13 @@ def _deserialize(self, target_obj, data): if isinstance(response, str): return self.deserialize_data(data, response) - elif isinstance(response, type) and issubclass(response, Enum): + if isinstance(response, type) and issubclass(response, Enum): return self.deserialize_enum(data, response) if data is None or data is CoreNull: return data try: - attributes = response._attribute_map # type: ignore + attributes = response._attribute_map # type: ignore # pylint: disable=protected-access d_attrs = {} for attr, attr_desc in attributes.items(): # Check empty string. If it's not empty, someone has a real "additionalProperties"... @@ -1476,9 +1565,8 @@ def _deserialize(self, target_obj, data): except (AttributeError, TypeError, KeyError) as err: msg = "Unable to deserialize to object: " + class_name # type: ignore raise DeserializationError(msg) from err - else: - additional_properties = self._build_additional_properties(attributes, data) - return self._instantiate_model(response, d_attrs, additional_properties) + additional_properties = self._build_additional_properties(attributes, data) + return self._instantiate_model(response, d_attrs, additional_properties) def _build_additional_properties(self, attribute_map, data): if not self.additional_properties_detection: @@ -1505,6 +1593,8 @@ def _classify_target(self, target, data): :param str target: The target object type to deserialize to. :param str/dict data: The response data to deserialize. + :return: The classified target object and its class name. + :rtype: tuple """ if target is None: return None, None @@ -1516,7 +1606,7 @@ def _classify_target(self, target, data): return target, target try: - target = target._classify(data, self.dependencies) # type: ignore + target = target._classify(data, self.dependencies) # type: ignore # pylint: disable=protected-access except AttributeError: pass # Target is not a Model, no classify return target, target.__class__.__name__ # type: ignore @@ -1531,10 +1621,12 @@ def failsafe_deserialize(self, target_obj, data, content_type=None): :param str target_obj: The target object type to deserialize to. :param str/dict data: The response data to deserialize. :param str content_type: Swagger "produces" if available. + :return: Deserialized object. + :rtype: object """ try: return self(target_obj, data, content_type=content_type) - except: + except: # pylint: disable=bare-except _LOGGER.debug( "Ran into a deserialization error. Ignoring since this is failsafe deserialization", exc_info=True ) @@ -1552,10 +1644,12 @@ def _unpack_content(raw_data, content_type=None): If raw_data is something else, bypass all logic and return it directly. - :param raw_data: Data to be processed. - :param content_type: How to parse if raw_data is a string/bytes. + :param obj raw_data: Data to be processed. + :param str content_type: How to parse if raw_data is a string/bytes. :raises JSONDecodeError: If JSON is requested and parsing is impossible. :raises UnicodeDecodeError: If bytes is not UTF8 + :rtype: object + :return: Unpacked content. """ # Assume this is enough to detect a Pipeline Response without importing it context = getattr(raw_data, "context", {}) @@ -1579,14 +1673,21 @@ def _unpack_content(raw_data, content_type=None): def _instantiate_model(self, response, attrs, additional_properties=None): """Instantiate a response model passing in deserialized args. - :param response: The response model class. - :param d_attrs: The deserialized response attributes. + :param Response response: The response model class. + :param dict attrs: The deserialized response attributes. + :param dict additional_properties: Additional properties to be set. + :rtype: Response + :return: The instantiated response model. """ if callable(response): subtype = getattr(response, "_subtype_map", {}) try: - readonly = [k for k, v in response._validation.items() if v.get("readonly")] - const = [k for k, v in response._validation.items() if v.get("constant")] + readonly = [ + k for k, v in response._validation.items() if v.get("readonly") # pylint: disable=protected-access + ] + const = [ + k for k, v in response._validation.items() if v.get("constant") # pylint: disable=protected-access + ] kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const} response_obj = response(**kwargs) for attr in readonly: @@ -1596,7 +1697,7 @@ def _instantiate_model(self, response, attrs, additional_properties=None): return response_obj except TypeError as err: msg = "Unable to deserialize {} into model {}. ".format(kwargs, response) # type: ignore - raise DeserializationError(msg + str(err)) + raise DeserializationError(msg + str(err)) from err else: try: for attr, value in attrs.items(): @@ -1605,15 +1706,16 @@ def _instantiate_model(self, response, attrs, additional_properties=None): except Exception as exp: msg = "Unable to populate response model. " msg += "Type: {}, Error: {}".format(type(response), exp) - raise DeserializationError(msg) + raise DeserializationError(msg) from exp - def deserialize_data(self, data, data_type): + def deserialize_data(self, data, data_type): # pylint: disable=too-many-return-statements """Process data for deserialization according to data type. :param str data: The response string to be deserialized. :param str data_type: The type to deserialize to. :raises: DeserializationError if deserialization fails. :return: Deserialized object. + :rtype: object """ if data is None: return data @@ -1627,7 +1729,11 @@ def deserialize_data(self, data, data_type): if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): return data - is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] + is_a_text_parsing_type = lambda x: x not in [ # pylint: disable=unnecessary-lambda-assignment + "object", + "[]", + r"{}", + ] if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: return None data_val = self.deserialize_type[data_type](data) @@ -1647,14 +1753,14 @@ def deserialize_data(self, data, data_type): msg = "Unable to deserialize response data." msg += " Data: {}, {}".format(data, data_type) raise DeserializationError(msg) from err - else: - return self._deserialize(obj_type, data) + return self._deserialize(obj_type, data) def deserialize_iter(self, attr, iter_type): """Deserialize an iterable. :param list attr: Iterable to be deserialized. :param str iter_type: The type of object in the iterable. + :return: Deserialized iterable. :rtype: list """ if attr is None: @@ -1671,6 +1777,7 @@ def deserialize_dict(self, attr, dict_type): :param dict/list attr: Dictionary to be deserialized. Also accepts a list of key, value pairs. :param str dict_type: The object type of the items in the dictionary. + :return: Deserialized dictionary. :rtype: dict """ if isinstance(attr, list): @@ -1681,11 +1788,12 @@ def deserialize_dict(self, attr, dict_type): attr = {el.tag: el.text for el in attr} return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} - def deserialize_object(self, attr, **kwargs): + def deserialize_object(self, attr, **kwargs): # pylint: disable=too-many-return-statements """Deserialize a generic object. This will be handled as a dictionary. :param dict attr: Dictionary to be deserialized. + :return: Deserialized object. :rtype: dict :raises: TypeError if non-builtin datatype encountered. """ @@ -1720,11 +1828,10 @@ def deserialize_object(self, attr, **kwargs): pass return deserialized - else: - error = "Cannot deserialize generic object with type: " - raise TypeError(error + str(obj_type)) + error = "Cannot deserialize generic object with type: " + raise TypeError(error + str(obj_type)) - def deserialize_basic(self, attr, data_type): + def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return-statements """Deserialize basic builtin data type from string. Will attempt to convert to str, int, float and bool. This function will also accept '1', '0', 'true' and 'false' as @@ -1732,6 +1839,7 @@ def deserialize_basic(self, attr, data_type): :param str attr: response string to be deserialized. :param str data_type: deserialization data type. + :return: Deserialized basic type. :rtype: str, int, float or bool :raises: TypeError if string format is not valid. """ @@ -1743,24 +1851,23 @@ def deserialize_basic(self, attr, data_type): if data_type == "str": # None or '', node is empty string. return "" - else: - # None or '', node with a strong type is None. - # Don't try to model "empty bool" or "empty int" - return None + # None or '', node with a strong type is None. + # Don't try to model "empty bool" or "empty int" + return None if data_type == "bool": if attr in [True, False, 1, 0]: return bool(attr) - elif isinstance(attr, str): + if isinstance(attr, str): if attr.lower() in ["true", "1"]: return True - elif attr.lower() in ["false", "0"]: + if attr.lower() in ["false", "0"]: return False raise TypeError("Invalid boolean value: {}".format(attr)) if data_type == "str": return self.deserialize_unicode(attr) - return eval(data_type)(attr) # nosec + return eval(data_type)(attr) # nosec # pylint: disable=eval-used @staticmethod def deserialize_unicode(data): @@ -1768,6 +1875,7 @@ def deserialize_unicode(data): as a string. :param str data: response string to be deserialized. + :return: Deserialized string. :rtype: str or unicode """ # We might be here because we have an enum modeled as string, @@ -1781,8 +1889,7 @@ def deserialize_unicode(data): return data except NameError: return str(data) - else: - return str(data) + return str(data) @staticmethod def deserialize_enum(data, enum_obj): @@ -1794,6 +1901,7 @@ def deserialize_enum(data, enum_obj): :param str data: Response string to be deserialized. If this value is None or invalid it will be returned as-is. :param Enum enum_obj: Enum object to deserialize to. + :return: Deserialized enum object. :rtype: Enum """ if isinstance(data, enum_obj) or data is None: @@ -1804,9 +1912,9 @@ def deserialize_enum(data, enum_obj): # Workaround. We might consider remove it in the future. try: return list(enum_obj.__members__.values())[data] - except IndexError: + except IndexError as exc: error = "{!r} is not a valid index for enum {!r}" - raise DeserializationError(error.format(data, enum_obj)) + raise DeserializationError(error.format(data, enum_obj)) from exc try: return enum_obj(str(data)) except ValueError: @@ -1822,6 +1930,7 @@ def deserialize_bytearray(attr): """Deserialize string into bytearray. :param str attr: response string to be deserialized. + :return: Deserialized bytearray :rtype: bytearray :raises: TypeError if string format invalid. """ @@ -1834,6 +1943,7 @@ def deserialize_base64(attr): """Deserialize base64 encoded string into string. :param str attr: response string to be deserialized. + :return: Deserialized base64 string :rtype: bytearray :raises: TypeError if string format invalid. """ @@ -1849,8 +1959,9 @@ def deserialize_decimal(attr): """Deserialize string into Decimal object. :param str attr: response string to be deserialized. - :rtype: Decimal + :return: Deserialized decimal :raises: DeserializationError if string format invalid. + :rtype: decimal """ if isinstance(attr, ET.Element): attr = attr.text @@ -1865,6 +1976,7 @@ def deserialize_long(attr): """Deserialize string into long (Py2) or int (Py3). :param str attr: response string to be deserialized. + :return: Deserialized int :rtype: long or int :raises: ValueError if string format invalid. """ @@ -1877,6 +1989,7 @@ def deserialize_duration(attr): """Deserialize ISO-8601 formatted string into TimeDelta object. :param str attr: response string to be deserialized. + :return: Deserialized duration :rtype: TimeDelta :raises: DeserializationError if string format invalid. """ @@ -1887,14 +2000,14 @@ def deserialize_duration(attr): except (ValueError, OverflowError, AttributeError) as err: msg = "Cannot deserialize duration object." raise DeserializationError(msg) from err - else: - return duration + return duration @staticmethod def deserialize_date(attr): """Deserialize ISO-8601 formatted string into Date object. :param str attr: response string to be deserialized. + :return: Deserialized date :rtype: Date :raises: DeserializationError if string format invalid. """ @@ -1910,6 +2023,7 @@ def deserialize_time(attr): """Deserialize ISO-8601 formatted string into time object. :param str attr: response string to be deserialized. + :return: Deserialized time :rtype: datetime.time :raises: DeserializationError if string format invalid. """ @@ -1924,6 +2038,7 @@ def deserialize_rfc(attr): """Deserialize RFC-1123 formatted string into Datetime object. :param str attr: response string to be deserialized. + :return: Deserialized RFC datetime :rtype: Datetime :raises: DeserializationError if string format invalid. """ @@ -1939,14 +2054,14 @@ def deserialize_rfc(attr): except ValueError as err: msg = "Cannot deserialize to rfc datetime object." raise DeserializationError(msg) from err - else: - return date_obj + return date_obj @staticmethod def deserialize_iso(attr): """Deserialize ISO-8601 formatted string into Datetime object. :param str attr: response string to be deserialized. + :return: Deserialized ISO datetime :rtype: Datetime :raises: DeserializationError if string format invalid. """ @@ -1976,8 +2091,7 @@ def deserialize_iso(attr): except (ValueError, OverflowError, AttributeError) as err: msg = "Cannot deserialize datetime object." raise DeserializationError(msg) from err - else: - return date_obj + return date_obj @staticmethod def deserialize_unix(attr): @@ -1985,6 +2099,7 @@ def deserialize_unix(attr): This is represented as seconds. :param int attr: Object to be serialized. + :return: Deserialized datetime :rtype: Datetime :raises: DeserializationError if format invalid """ @@ -1996,5 +2111,4 @@ def deserialize_unix(attr): except ValueError as err: msg = "Cannot deserialize to unix datetime object." raise DeserializationError(msg) from err - else: - return date_obj + return date_obj diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py index 8ea240fb008b..147e96be133e 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py @@ -15,7 +15,6 @@ ) if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core import PipelineClient from ._serialization import Deserializer, Serializer diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py index c7d155d924dd..be71c81bd282 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b5" +VERSION = "1.0.0b1" diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py index c31764c00803..668f989a5838 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py @@ -5,21 +5,29 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._patch import ChatCompletionsClient -from ._patch import EmbeddingsClient -from ._patch import ImageEmbeddingsClient +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import -from ._patch import load_client +from ._client import ChatCompletionsClient # type: ignore +from ._client import EmbeddingsClient # type: ignore +from ._client import ImageEmbeddingsClient # type: ignore + +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] from ._patch import patch_sdk as _patch_sdk __all__ = [ - "load_client", "ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient", ] - +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py index 30c7afbfbd91..0dff39293bd7 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py @@ -28,19 +28,17 @@ ) if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials_async import AsyncTokenCredential -class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): """ChatCompletionsClient. :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported @@ -115,15 +113,14 @@ async def __aexit__(self, *exc_details: Any) -> None: await self._client.__aexit__(*exc_details) -class EmbeddingsClient(EmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class EmbeddingsClient(EmbeddingsClientOperationsMixin): """EmbeddingsClient. :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported @@ -198,15 +195,14 @@ async def __aexit__(self, *exc_details: Any) -> None: await self._client.__aexit__(*exc_details) -class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): """ImageEmbeddingsClient. :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py index e4c5d7111d22..36e81c17a277 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py @@ -14,11 +14,10 @@ from .._version import VERSION if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials_async import AsyncTokenCredential -class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ChatCompletionsClient. Note that all parameters used to create this instance are saved as instance @@ -26,10 +25,9 @@ class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-a :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported @@ -58,8 +56,6 @@ def __init__( def _infer_policy(self, **kwargs): if isinstance(self.credential, AzureKeyCredential): return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) - if isinstance(self.credential, AzureKeyCredential): - return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) if hasattr(self.credential, "get_token"): return policies.AsyncBearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) raise TypeError(f"Unsupported credential: {self.credential}") @@ -78,7 +74,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for EmbeddingsClient. Note that all parameters used to create this instance are saved as instance @@ -86,10 +82,9 @@ class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attrib :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported @@ -118,8 +113,6 @@ def __init__( def _infer_policy(self, **kwargs): if isinstance(self.credential, AzureKeyCredential): return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) - if isinstance(self.credential, AzureKeyCredential): - return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) if hasattr(self.credential, "get_token"): return policies.AsyncBearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) raise TypeError(f"Unsupported credential: {self.credential}") @@ -138,7 +131,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ImageEmbeddingsClient. Note that all parameters used to create this instance are saved as instance @@ -146,10 +139,9 @@ class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-a :param endpoint: Service host. Required. :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is one of the - following types: AzureKeyCredential, AzureKeyCredential, TokenCredential Required. + :param credential: Credential used to authenticate requests to the service. Is either a + AzureKeyCredential type or a TokenCredential type. Required. :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials_async.AsyncTokenCredential :keyword api_version: The API version to use for this operation. Default value is "2024-05-01-preview". Note that overriding this default value may result in unsupported @@ -178,8 +170,6 @@ def __init__( def _infer_policy(self, **kwargs): if isinstance(self.credential, AzureKeyCredential): return policies.AzureKeyCredentialPolicy(self.credential, "Authorization", prefix="Bearer", **kwargs) - if isinstance(self.credential, AzureKeyCredential): - return policies.AzureKeyCredentialPolicy(self.credential, "api-key", **kwargs) if hasattr(self.credential, "get_token"): return policies.AsyncBearerTokenCredentialPolicy(self.credential, *self.credential_scopes, **kwargs) raise TypeError(f"Unsupported credential: {self.credential}") diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py index d3ebd561f739..ab87088736aa 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py @@ -5,13 +5,19 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._operations import ChatCompletionsClientOperationsMixin -from ._operations import EmbeddingsClientOperationsMixin -from ._operations import ImageEmbeddingsClientOperationsMixin +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import ChatCompletionsClientOperationsMixin # type: ignore +from ._operations import EmbeddingsClientOperationsMixin # type: ignore +from ._operations import ImageEmbeddingsClientOperationsMixin # type: ignore from ._patch import __all__ as _patch_all -from ._patch import * # pylint: disable=unused-wildcard-import +from ._patch import * from ._patch import patch_sdk as _patch_sdk __all__ = [ @@ -19,5 +25,5 @@ "EmbeddingsClientOperationsMixin", "ImageEmbeddingsClientOperationsMixin", ] -__all__.extend([p for p in _patch_all if p not in __all__]) +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py index 0be948bd275d..ebffaf05663c 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py @@ -1,4 +1,3 @@ -# pylint: disable=too-many-lines,too-many-statements # coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. @@ -9,7 +8,7 @@ from io import IOBase import json import sys -from typing import Any, Callable, Dict, IO, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload from azure.core.exceptions import ( ClientAuthenticationError, @@ -41,7 +40,7 @@ if sys.version_info >= (3, 9): from collections.abc import MutableMapping else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports + from typing import MutableMapping # type: ignore JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object _Unset: Any = object() T = TypeVar("T") @@ -63,23 +62,9 @@ async def _complete( async def _complete( self, *, - messages: List[_models.ChatRequestMessage], + chat_completions_options: _models._models.ChatCompletionsOptions, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, content_type: str = "application/json", - frequency_penalty: Optional[float] = None, - stream_parameter: Optional[bool] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, **kwargs: Any ) -> _models.ChatCompletions: ... @overload @@ -97,25 +82,10 @@ async def _complete( self, body: Union[JSON, IO[bytes]] = _Unset, *, - messages: List[_models.ChatRequestMessage] = _Unset, + chat_completions_options: _models._models.ChatCompletionsOptions = _Unset, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, - frequency_penalty: Optional[float] = None, - stream_parameter: Optional[bool] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, **kwargs: Any ) -> _models.ChatCompletions: - # pylint: disable=too-many-locals """Gets chat completions for the provided chat messages. Completions support a wide variety of tasks and generate text that continues from or "completes" @@ -124,87 +94,16 @@ async def _complete( :param body: Is either a JSON type or a IO[bytes] type. Required. :type body: JSON or IO[bytes] - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, - are passed in the JSON request payload. - This sets the HTTP request header ``extra-parameters``. Known values are: "error", "drop", and - "pass-through". Default value is None. + :keyword chat_completions_options: Required. + :paramtype chat_completions_options: ~azure.ai.inference.models._models.ChatCompletionsOptions + :keyword extra_params: Known values are: "error", "drop", and "pass-through". Default value is + None. :paramtype extra_params: str or ~azure.ai.inference.models.ExtraParameters - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative - frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. Default value is None. - :paramtype frequency_penalty: float - :keyword stream_parameter: A value indicating whether chat completions should be streamed for - this request. Default value is None. - :paramtype stream_parameter: bool - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the - model's likelihood to output new topics. - Supported range is [-2, 2]. Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON - via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: A list of tools the model may request to call. Currently, only functions are - supported as a tool. The model - may response with a function call request and provide the input arguments in JSON format for - that function. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. Default - value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str :return: ChatCompletions. The ChatCompletions is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ChatCompletions :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -219,23 +118,9 @@ async def _complete( cls: ClsType[_models.ChatCompletions] = kwargs.pop("cls", None) if body is _Unset: - if messages is _Unset: - raise TypeError("missing required argument: messages") - body = { - "frequency_penalty": frequency_penalty, - "max_tokens": max_tokens, - "messages": messages, - "model": model, - "presence_penalty": presence_penalty, - "response_format": response_format, - "seed": seed, - "stop": stop, - "stream": stream_parameter, - "temperature": temperature, - "tool_choice": tool_choice, - "tools": tools, - "top_p": top_p, - } + if chat_completions_options is _Unset: + raise TypeError("missing required argument: chat_completions_options") + body = {"chatCompletionsOptions": chat_completions_options} body = {k: v for k, v in body.items() if v is not None} content_type = content_type or "application/json" _content = None @@ -287,14 +172,12 @@ async def _complete( async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """Returns information about the AI model. The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. :return: ModelInfo. The ModelInfo is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -425,7 +308,7 @@ async def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -500,14 +383,12 @@ async def _embed( async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """Returns information about the AI model. The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. :return: ModelInfo. The ModelInfo is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -641,7 +522,7 @@ async def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -716,14 +597,12 @@ async def _embed( async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: """Returns information about the AI model. The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. :return: ModelInfo. The ModelInfo is compatible with MutableMapping :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py index 2bdfd67a40cb..f7dd32510333 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py @@ -2,1266 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -# pylint: disable=too-many-lines) """Customize generated code here. Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -import json -import logging -import sys +from typing import List -from io import IOBase -from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, AsyncIterable - -from azure.core.pipeline import PipelineResponse -from azure.core.credentials import AzureKeyCredential -from azure.core.tracing.decorator_async import distributed_trace_async -from azure.core.utils import case_insensitive_dict -from azure.core.exceptions import ( - ClientAuthenticationError, - HttpResponseError, - map_error, - ResourceExistsError, - ResourceNotFoundError, - ResourceNotModifiedError, -) -from .. import models as _models -from .._model_base import SdkJSONEncoder, _deserialize -from ._client import ChatCompletionsClient as ChatCompletionsClientGenerated -from ._client import EmbeddingsClient as EmbeddingsClientGenerated -from ._client import ImageEmbeddingsClient as ImageEmbeddingsClientGenerated -from .._operations._operations import ( - build_chat_completions_complete_request, - build_embeddings_embed_request, - build_image_embeddings_embed_request, -) - -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from azure.core.credentials_async import AsyncTokenCredential - -if sys.version_info >= (3, 9): - from collections.abc import MutableMapping -else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports - -JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object -_Unset: Any = object() -_LOGGER = logging.getLogger(__name__) - - -async def load_client( - endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any -) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: - # pylint: disable=line-too-long - """ - Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route - on the given endpoint, to determine the model type and therefore which client to instantiate. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - :return: The appropriate asynchronous client associated with the given endpoint - :rtype: ~azure.ai.inference.aio.ChatCompletionsClient or ~azure.ai.inference.aio.EmbeddingsClient - or ~azure.ai.inference.aio.ImageEmbeddingsClient - :raises ~azure.core.exceptions.HttpResponseError: - """ - - async with ChatCompletionsClient( - endpoint, credential, **kwargs - ) as client: # Pick any of the clients, it does not matter. - model_info = await client.get_model_info() # type: ignore - - _LOGGER.info("model_info=%s", model_info) - if not model_info.model_type: - raise ValueError( - "The AI model information is missing a value for `model type`. Cannot create an appropriate client." - ) - - # TODO: Remove "completions" and "embedding" once Mistral Large and Cohere fixes their model type - if model_info.model_type in (_models.ModelType.CHAT, "completion", "chat-completion", "chat-completions"): - chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs) - chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return chat_completion_client - - if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"): - embedding_client = EmbeddingsClient(endpoint, credential, **kwargs) - embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init - return embedding_client - - if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS: - image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs) - image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return image_embedding_client - - raise ValueError(f"No client available to support AI model type `{model_info.model_type}`") - - -class ChatCompletionsClient(ChatCompletionsClientGenerated): # pylint: disable=too-many-instance-attributes - """ChatCompletionsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "AsyncTokenCredential"], - *, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default chat completions settings, to be applied in all future service calls - # unless overridden by arguments in the `complete` method. - self._frequency_penalty = frequency_penalty - self._presence_penalty = presence_penalty - self._temperature = temperature - self._top_p = top_p - self._max_tokens = max_tokens - self._response_format = response_format - self._stop = stop - self._tools = tools - self._tool_choice = tool_choice - self._seed = seed - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - async def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Literal[False] = False, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.ChatCompletions: ... - - @overload - async def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Literal[True], - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> AsyncIterable[_models.StreamingChatCompletionsUpdate]: ... - - @overload - async def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. The method makes a REST API call to the `/chat/completions` route - on the given endpoint. - When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting StreamingChatCompletions - object to get content updates as they arrive. By default, the response is a ChatCompletions object - (non-streaming). - - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def complete( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def complete( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - # pylint:disable=client-method-missing-tracing-decorator-async - async def complete( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - messages: List[_models.ChatRequestMessage] = _Unset, - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - # pylint: disable=too-many-locals - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting :class:`~azure.ai.inference.models.StreamingChatCompletions` - object to get content updates as they arrive. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if messages is _Unset: - raise TypeError("missing required argument: messages") - body = { - "messages": messages, - "stream": stream, - "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, - "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, - "model": model if model is not None else self._model, - "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, - "response_format": response_format if response_format is not None else self._response_format, - "seed": seed if seed is not None else self._seed, - "stop": stop if stop is not None else self._stop, - "temperature": temperature if temperature is not None else self._temperature, - "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, - "tools": tools if tools is not None else self._tools, - "top_p": top_p if top_p is not None else self._top_p, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): - stream = body["stream"] - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_chat_completions_complete_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = stream or False - pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - await response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - return _models.AsyncStreamingChatCompletions(response) - - return _deserialize(_models._patch.ChatCompletions, response.json()) # pylint: disable=protected-access - - @distributed_trace_async - async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class EmbeddingsClient(EmbeddingsClientGenerated): - """EmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "AsyncTokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - async def embed( - self, - *, - input: List[str], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace_async - async def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[str] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - await response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace_async - async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): - """ImageEmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "AsyncTokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - async def embed( - self, - *, - input: List[_models.EmbeddingInput], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace_async - async def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[_models.EmbeddingInput] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_image_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - await response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace_async - async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -__all__: List[str] = [ - "load_client", - "ChatCompletionsClient", - "EmbeddingsClient", - "ImageEmbeddingsClient", -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py index dd91e1ea130f..b430582ca1fc 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py @@ -15,7 +15,6 @@ ) if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core import AsyncPipelineClient from .._serialization import Deserializer, Serializer diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py index 1832edc83399..5b296cd9a896 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py @@ -5,59 +5,71 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._models import AssistantMessage -from ._models import ChatChoice -from ._patch import ChatCompletions -from ._models import ChatCompletionsNamedToolChoice -from ._models import ChatCompletionsNamedToolChoiceFunction -from ._models import ChatCompletionsResponseFormat -from ._models import ChatCompletionsResponseFormatJSON -from ._models import ChatCompletionsResponseFormatText -from ._models import ChatCompletionsToolCall -from ._models import ChatCompletionsToolDefinition -from ._models import ChatRequestMessage -from ._models import ChatResponseMessage -from ._models import CompletionsUsage -from ._models import ContentItem -from ._models import EmbeddingInput -from ._models import EmbeddingItem -from ._patch import EmbeddingsResult -from ._models import EmbeddingsUsage -from ._models import FunctionCall -from ._models import FunctionDefinition -from ._models import ImageContentItem -from ._patch import ImageUrl -from ._models import ModelInfo -from ._models import StreamingChatChoiceUpdate -from ._models import StreamingChatCompletionsUpdate -from ._models import StreamingChatResponseMessageUpdate -from ._models import StreamingChatResponseToolCallUpdate -from ._models import SystemMessage -from ._models import TextContentItem -from ._models import ToolMessage -from ._models import UserMessage +from typing import TYPE_CHECKING -from ._enums import ChatCompletionsToolChoicePreset -from ._enums import ChatRole -from ._enums import CompletionsFinishReason -from ._enums import EmbeddingEncodingFormat -from ._enums import EmbeddingInputType -from ._enums import ImageDetailLevel -from ._enums import ModelType +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import -from ._patch import StreamingChatCompletions -from ._patch import AsyncStreamingChatCompletions + +from ._models import ( # type: ignore + AssistantMessage, + ChatChoice, + ChatCompletions, + ChatCompletionsFunctionToolCall, + ChatCompletionsFunctionToolDefinition, + ChatCompletionsFunctionToolSelection, + ChatCompletionsNamedFunctionToolSelection, + ChatCompletionsNamedToolSelection, + ChatCompletionsResponseFormat, + ChatCompletionsResponseFormatJSON, + ChatCompletionsResponseFormatText, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + ChatRequestMessage, + ChatResponseMessage, + CompletionsUsage, + ContentItem, + EmbeddingInput, + EmbeddingItem, + EmbeddingsResult, + EmbeddingsUsage, + FunctionCall, + FunctionDefinition, + ImageContentItem, + ImageUrl, + ModelInfo, + StreamingChatChoiceUpdate, + StreamingChatCompletionsUpdate, + SystemMessage, + TextContentItem, + ToolMessage, + UserMessage, +) + +from ._enums import ( # type: ignore + ChatCompletionsToolSelectionPreset, + ChatRole, + CompletionsFinishReason, + EmbeddingEncodingFormat, + EmbeddingInputType, + ImageDetailLevel, + ModelType, +) +from ._patch import __all__ as _patch_all +from ._patch import * from ._patch import patch_sdk as _patch_sdk __all__ = [ - "StreamingChatCompletions", - "AsyncStreamingChatCompletions", "AssistantMessage", "ChatChoice", "ChatCompletions", - "ChatCompletionsNamedToolChoice", - "ChatCompletionsNamedToolChoiceFunction", + "ChatCompletionsFunctionToolCall", + "ChatCompletionsFunctionToolDefinition", + "ChatCompletionsFunctionToolSelection", + "ChatCompletionsNamedFunctionToolSelection", + "ChatCompletionsNamedToolSelection", "ChatCompletionsResponseFormat", "ChatCompletionsResponseFormatJSON", "ChatCompletionsResponseFormatText", @@ -78,13 +90,11 @@ "ModelInfo", "StreamingChatChoiceUpdate", "StreamingChatCompletionsUpdate", - "StreamingChatResponseMessageUpdate", - "StreamingChatResponseToolCallUpdate", "SystemMessage", "TextContentItem", "ToolMessage", "UserMessage", - "ChatCompletionsToolChoicePreset", + "ChatCompletionsToolSelectionPreset", "ChatRole", "CompletionsFinishReason", "EmbeddingEncodingFormat", @@ -92,5 +102,5 @@ "ImageDetailLevel", "ModelType", ] - +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py index 830a93f75472..f999084b481c 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py @@ -10,7 +10,7 @@ from azure.core import CaseInsensitiveEnumMeta -class ChatCompletionsToolChoicePreset(str, Enum, metaclass=CaseInsensitiveEnumMeta): +class ChatCompletionsToolSelectionPreset(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Represents a generic policy for how a chat completions tool may be selected.""" AUTO = "auto" @@ -76,12 +76,11 @@ class EmbeddingInputType(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Represents the input types used for embedding search.""" TEXT = "text" - """Indicates the input is a general text input.""" + """to do""" QUERY = "query" - """Indicates the input represents a search query to find the most relevant documents in your - vector database.""" + """to do""" DOCUMENT = "document" - """Indicates the input represents a document that is stored in a vector database.""" + """to do""" class ExtraParameters(str, Enum, metaclass=CaseInsensitiveEnumMeta): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py index 4ac8f16f94d1..fb140c2903b5 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py @@ -1,11 +1,12 @@ -# coding=utf-8 # pylint: disable=too-many-lines +# coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=useless-super-delegation import datetime from typing import Any, Dict, List, Literal, Mapping, Optional, TYPE_CHECKING, Union, overload @@ -15,7 +16,6 @@ from ._enums import ChatRole if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from .. import models as _models @@ -42,16 +42,16 @@ def __init__( self, *, role: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -88,16 +88,16 @@ def __init__( *, content: Optional[str] = None, tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.ASSISTANT, **kwargs) @@ -132,16 +132,16 @@ def __init__( index: int, finish_reason: Union[str, "_models.CompletionsFinishReason"], message: "_models.ChatResponseMessage", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -193,49 +193,124 @@ def __init__( model: str, usage: "_models.CompletionsUsage", choices: List["_models.ChatChoice"], - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) -class ChatCompletionsNamedToolChoice(_model_base.Model): - """A tool selection of a specific, named function tool that will limit chat completions to using - the named function. +class ChatCompletionsToolCall(_model_base.Model): + """An abstract representation of a tool call that must be resolved in a subsequent request to + perform the requested + chat completion. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ChatCompletionsFunctionToolCall + + + :ivar type: The object type. Required. Default value is None. + :vartype type: str + :ivar id: The ID of the tool call. Required. + :vartype id: str + """ + + __mapping__: Dict[str, _model_base.Model] = {} + type: str = rest_discriminator(name="type") + """The object type. Required. Default value is None.""" + id: str = rest_field() + """The ID of the tool call. Required.""" + + @overload + def __init__( + self, + *, + type: str, + id: str, # pylint: disable=redefined-builtin + ) -> None: ... - Readonly variables are only populated by the server, and will be ignored when sending a request. + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChatCompletionsFunctionToolCall(ChatCompletionsToolCall, discriminator="function"): + """A tool call to a function tool, issued by the model in evaluation of a configured function + tool, that represents + a function invocation needed for a subsequent chat completions request to resolve. + + + :ivar id: The ID of the tool call. Required. + :vartype id: str + :ivar type: The type of tool call, in this case always 'function'. Required. Default value is + "function". + :vartype type: str + :ivar function: The details of the function invocation requested by the tool call. Required. + :vartype function: ~azure.ai.inference.models.FunctionCall + """ + + type: Literal["function"] = rest_discriminator(name="type") # type: ignore + """The type of tool call, in this case always 'function'. Required. Default value is \"function\".""" + function: "_models.FunctionCall" = rest_field() + """The details of the function invocation requested by the tool call. Required.""" + + @overload + def __init__( + self, + *, + id: str, # pylint: disable=redefined-builtin + function: "_models.FunctionCall", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="function", **kwargs) + + +class ChatCompletionsToolDefinition(_model_base.Model): + """An abstract representation of a tool that can be used by the model to improve a chat + completions response. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ChatCompletionsFunctionToolDefinition All required parameters must be populated in order to send to server. - :ivar type: The type of the tool. Currently, only ``function`` is supported. Required. Default - value is "function". + :ivar type: The object type. Required. Default value is None. :vartype type: str - :ivar function: The function that should be called. Required. - :vartype function: ~azure.ai.inference.models.ChatCompletionsNamedToolChoiceFunction """ - type: Literal["function"] = rest_field() - """The type of the tool. Currently, only ``function`` is supported. Required. Default value is - \"function\".""" - function: "_models.ChatCompletionsNamedToolChoiceFunction" = rest_field() - """The function that should be called. Required.""" + __mapping__: Dict[str, _model_base.Model] = {} + type: str = rest_discriminator(name="type") + """The object type. Required. Default value is None.""" @overload def __init__( self, *, - function: "_models.ChatCompletionsNamedToolChoiceFunction", - ): ... + type: str, + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] @@ -243,10 +318,44 @@ def __init__(self, mapping: Mapping[str, Any]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.type: Literal["function"] = "function" -class ChatCompletionsNamedToolChoiceFunction(_model_base.Model): +class ChatCompletionsFunctionToolDefinition(ChatCompletionsToolDefinition, discriminator="function"): + """The definition information for a chat completions function tool that can call a function in + response to a tool call. + + All required parameters must be populated in order to send to server. + + :ivar type: The object name, which is always 'function'. Required. Default value is "function". + :vartype type: str + :ivar function: The function definition details for the function tool. Required. + :vartype function: ~azure.ai.inference.models.FunctionDefinition + """ + + type: Literal["function"] = rest_discriminator(name="type") # type: ignore + """The object name, which is always 'function'. Required. Default value is \"function\".""" + function: "_models.FunctionDefinition" = rest_field() + """The function definition details for the function tool. Required.""" + + @overload + def __init__( + self, + *, + function: "_models.FunctionDefinition", + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="function", **kwargs) + + +class ChatCompletionsFunctionToolSelection(_model_base.Model): """A tool selection of a specific, named function tool that will limit chat completions to using the named function. @@ -264,201 +373,327 @@ def __init__( self, *, name: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) -class ChatCompletionsResponseFormat(_model_base.Model): - """Represents the format that the model must output. Use this to enable JSON mode instead of the - default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON - via a system or user message. +class ChatCompletionsNamedToolSelection(_model_base.Model): + """An abstract representation of an explicit, named tool selection to use for a chat completions + request. You probably want to use the sub-classes and not this class directly. Known sub-classes are: - ChatCompletionsResponseFormatJSON, ChatCompletionsResponseFormatText + ChatCompletionsNamedFunctionToolSelection All required parameters must be populated in order to send to server. - :ivar type: The response format type to use for chat completions. Required. Default value is - None. + :ivar type: The object type. Required. Default value is None. :vartype type: str """ __mapping__: Dict[str, _model_base.Model] = {} type: str = rest_discriminator(name="type") - """The response format type to use for chat completions. Required. Default value is None.""" + """The object type. Required. Default value is None.""" @overload def __init__( self, *, type: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) -class ChatCompletionsResponseFormatJSON(ChatCompletionsResponseFormat, discriminator="json_object"): - """A response format for Chat Completions that restricts responses to emitting valid JSON objects. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON - via a system or user message. +class ChatCompletionsNamedFunctionToolSelection( + ChatCompletionsNamedToolSelection, discriminator="function" +): # pylint: disable=name-too-long + """A tool selection of a specific, named function tool that will limit chat completions to using + the named function. All required parameters must be populated in order to send to server. - :ivar type: Response format type: always 'json_object' for this object. Required. Default value - is "json_object". + :ivar type: The object type, which is always 'function'. Required. Default value is "function". :vartype type: str + :ivar function: The function that should be called. Required. + :vartype function: ~azure.ai.inference.models.ChatCompletionsFunctionToolSelection """ - type: Literal["json_object"] = rest_discriminator(name="type") # type: ignore - """Response format type: always 'json_object' for this object. Required. Default value is - \"json_object\".""" + type: Literal["function"] = rest_discriminator(name="type") # type: ignore + """The object type, which is always 'function'. Required. Default value is \"function\".""" + function: "_models.ChatCompletionsFunctionToolSelection" = rest_field() + """The function that should be called. Required.""" @overload def __init__( self, - ): ... + *, + function: "_models.ChatCompletionsFunctionToolSelection", + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation - super().__init__(*args, type="json_object", **kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, type="function", **kwargs) -class ChatCompletionsResponseFormatText(ChatCompletionsResponseFormat, discriminator="text"): - """A response format for Chat Completions that emits text responses. This is the default response - format. +class ChatCompletionsOptions(_model_base.Model): + """The configuration information for a chat completions request. + Completions support a wide variety of tasks and generate text that continues from or + "completes" + provided prompt data. All required parameters must be populated in order to send to server. - :ivar type: Response format type: always 'text' for this object. Required. Default value is - "text". + :ivar messages: The collection of context messages associated with this chat completions + request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required. + :vartype messages: list[~azure.ai.inference.models.ChatRequestMessage] + :ivar frequency_penalty: A value that influences the probability of generated tokens appearing + based on their cumulative + frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2]. + :vartype frequency_penalty: float + :ivar stream: A value indicating whether chat completions should be streamed for this request. + :vartype stream: bool + :ivar presence_penalty: A value that influences the probability of generated tokens appearing + based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase + the + model's likelihood to output new topics. + Supported range is [-2, 2]. + :vartype presence_penalty: float + :ivar temperature: The sampling temperature to use that controls the apparent creativity of + generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + :vartype temperature: float + :ivar top_p: An alternative to sampling with temperature called nucleus sampling. This value + causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1]. + :vartype top_p: float + :ivar max_tokens: The maximum number of tokens to generate. + :vartype max_tokens: int + :ivar response_format: The format that the model must output. Use this to enable JSON mode + instead of the default text mode. + Note that to enable JSON mode, some AI models may also require you to instruct the model to + produce JSON + via a system or user message. + :vartype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat + :ivar stop: A collection of textual sequences that will end completions generation. + :vartype stop: list[str] + :ivar tools: The available tool definitions that the chat completions request can use, + including caller-defined functions. + :vartype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] + :ivar tool_choice: If specified, the model will configure which of the provided tools it can + use for the chat completions response. Is either a Union[str, + "_models.ChatCompletionsToolSelectionPreset"] type or a ChatCompletionsNamedToolSelection type. + :vartype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolSelectionPreset or + ~azure.ai.inference.models.ChatCompletionsNamedToolSelection + :ivar seed: If specified, the system will make a best effort to sample deterministically such + that repeated requests with the + same seed and parameters should return the same result. Determinism is not guaranteed. + :vartype seed: int + :ivar model: ID of the specific AI model to use, if more than one model is available on the + endpoint. + :vartype model: str + """ + + messages: List["_models.ChatRequestMessage"] = rest_field() + """The collection of context messages associated with this chat completions request. + Typical usage begins with a chat message for the System role that provides instructions for + the behavior of the assistant, followed by alternating messages between the User and + Assistant roles. Required.""" + frequency_penalty: Optional[float] = rest_field() + """A value that influences the probability of generated tokens appearing based on their cumulative + frequency in generated text. + Positive values will make tokens less likely to appear as their frequency increases and + decrease the likelihood of the model repeating the same statements verbatim. + Supported range is [-2, 2].""" + stream: Optional[bool] = rest_field() + """A value indicating whether chat completions should be streamed for this request.""" + presence_penalty: Optional[float] = rest_field() + """A value that influences the probability of generated tokens appearing based on their existing + presence in generated text. + Positive values will make tokens less likely to appear when they already exist and increase the + model's likelihood to output new topics. + Supported range is [-2, 2].""" + temperature: Optional[float] = rest_field() + """The sampling temperature to use that controls the apparent creativity of generated completions. + Higher values will make output more random while lower values will make results more focused + and deterministic. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1].""" + top_p: Optional[float] = rest_field() + """An alternative to sampling with temperature called nucleus sampling. This value causes the + model to consider the results of tokens with the provided probability mass. As an example, a + value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be + considered. + It is not recommended to modify temperature and top_p for the same completions request as the + interaction of these two settings is difficult to predict. + Supported range is [0, 1].""" + max_tokens: Optional[int] = rest_field() + """The maximum number of tokens to generate.""" + response_format: Optional["_models.ChatCompletionsResponseFormat"] = rest_field() + """The format that the model must output. Use this to enable JSON mode instead of the default text + mode. + Note that to enable JSON mode, some AI models may also require you to instruct the model to + produce JSON + via a system or user message.""" + stop: Optional[List[str]] = rest_field() + """A collection of textual sequences that will end completions generation.""" + tools: Optional[List["_models.ChatCompletionsToolDefinition"]] = rest_field() + """The available tool definitions that the chat completions request can use, including + caller-defined functions.""" + tool_choice: Optional[ + Union[str, "_models.ChatCompletionsToolSelectionPreset", "_models.ChatCompletionsNamedToolSelection"] + ] = rest_field() + """If specified, the model will configure which of the provided tools it can use for the chat + completions response. Is either a Union[str, \"_models.ChatCompletionsToolSelectionPreset\"] + type or a ChatCompletionsNamedToolSelection type.""" + seed: Optional[int] = rest_field() + """If specified, the system will make a best effort to sample deterministically such that repeated + requests with the + same seed and parameters should return the same result. Determinism is not guaranteed.""" + model: Optional[str] = rest_field() + """ID of the specific AI model to use, if more than one model is available on the endpoint.""" + + +class ChatCompletionsResponseFormat(_model_base.Model): + """Represents the format that the model must output. Use this to enable JSON mode instead of the + default text mode. + Note that to enable JSON mode, some AI models may also require you to instruct the model to + produce JSON + via a system or user message. + + You probably want to use the sub-classes and not this class directly. Known sub-classes are: + ChatCompletionsResponseFormatJSON, ChatCompletionsResponseFormatText + + All required parameters must be populated in order to send to server. + + :ivar type: The response format type to use for chat completions. Required. Default value is + None. :vartype type: str """ - type: Literal["text"] = rest_discriminator(name="type") # type: ignore - """Response format type: always 'text' for this object. Required. Default value is \"text\".""" + __mapping__: Dict[str, _model_base.Model] = {} + type: str = rest_discriminator(name="type") + """The response format type to use for chat completions. Required. Default value is None.""" @overload def __init__( self, - ): ... + *, + type: str, + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation - super().__init__(*args, type="text", **kwargs) - + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) -class ChatCompletionsToolCall(_model_base.Model): - """A function tool call requested by the AI model. - Readonly variables are only populated by the server, and will be ignored when sending a request. +class ChatCompletionsResponseFormatJSON(ChatCompletionsResponseFormat, discriminator="json_object"): + """A response format for Chat Completions that restricts responses to emitting valid JSON objects. + Note that to enable JSON mode, some AI models may also require you to instruct the model to + produce JSON + via a system or user message. + All required parameters must be populated in order to send to server. - :ivar id: The ID of the tool call. Required. - :vartype id: str - :ivar type: The type of tool call. Currently, only ``function`` is supported. Required. Default - value is "function". + :ivar type: Response format type: always 'json_object' for this object. Required. Default value + is "json_object". :vartype type: str - :ivar function: The details of the function call requested by the AI model. Required. - :vartype function: ~azure.ai.inference.models.FunctionCall """ - id: str = rest_field() - """The ID of the tool call. Required.""" - type: Literal["function"] = rest_field() - """The type of tool call. Currently, only ``function`` is supported. Required. Default value is - \"function\".""" - function: "_models.FunctionCall" = rest_field() - """The details of the function call requested by the AI model. Required.""" + type: Literal["json_object"] = rest_discriminator(name="type") # type: ignore + """Response format type: always 'json_object' for this object. Required. Default value is + \"json_object\".""" @overload def __init__( self, - *, - id: str, # pylint: disable=redefined-builtin - function: "_models.FunctionCall", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.type: Literal["function"] = "function" + super().__init__(*args, type="json_object", **kwargs) -class ChatCompletionsToolDefinition(_model_base.Model): - """The definition of a chat completions tool that can call a function. - - Readonly variables are only populated by the server, and will be ignored when sending a request. +class ChatCompletionsResponseFormatText(ChatCompletionsResponseFormat, discriminator="text"): + """A response format for Chat Completions that emits text responses. This is the default response + format. All required parameters must be populated in order to send to server. - :ivar type: The type of the tool. Currently, only ``function`` is supported. Required. Default - value is "function". + :ivar type: Response format type: always 'text' for this object. Required. Default value is + "text". :vartype type: str - :ivar function: The function definition details for the function tool. Required. - :vartype function: ~azure.ai.inference.models.FunctionDefinition """ - type: Literal["function"] = rest_field() - """The type of the tool. Currently, only ``function`` is supported. Required. Default value is - \"function\".""" - function: "_models.FunctionDefinition" = rest_field() - """The function definition details for the function tool. Required.""" + type: Literal["text"] = rest_discriminator(name="type") # type: ignore + """Response format type: always 'text' for this object. Required. Default value is \"text\".""" @overload def __init__( self, - *, - function: "_models.FunctionDefinition", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.type: Literal["function"] = "function" + super().__init__(*args, type="text", **kwargs) class ChatResponseMessage(_model_base.Model): @@ -493,16 +728,16 @@ def __init__( role: Union[str, "_models.ChatRole"], content: str, tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -537,16 +772,16 @@ def __init__( completion_tokens: int, prompt_tokens: int, total_tokens: int, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -571,16 +806,16 @@ def __init__( self, *, type: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -608,16 +843,16 @@ def __init__( *, image: str, text: Optional[str] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -625,19 +860,17 @@ class EmbeddingItem(_model_base.Model): """Representation of a single embeddings relatedness comparison. - :ivar embedding: List of embedding values for the input prompt. These represent a measurement + :ivar embedding: List of embeddings value for the input prompt. These represent a measurement of the - vector-based relatedness of the provided input. Or a base64 encoded string of the embedding - vector. Required. Is either a str type or a [float] type. - :vartype embedding: str or list[float] + vector-based relatedness of the provided input. Required. + :vartype embedding: list[float] :ivar index: Index of the prompt to which the EmbeddingItem corresponds. Required. :vartype index: int """ - embedding: Union["str", List[float]] = rest_field() - """List of embedding values for the input prompt. These represent a measurement of the - vector-based relatedness of the provided input. Or a base64 encoded string of the embedding - vector. Required. Is either a str type or a [float] type.""" + embedding: List[float] = rest_field() + """List of embeddings value for the input prompt. These represent a measurement of the + vector-based relatedness of the provided input. Required.""" index: int = rest_field() """Index of the prompt to which the EmbeddingItem corresponds. Required.""" @@ -645,18 +878,18 @@ class EmbeddingItem(_model_base.Model): def __init__( self, *, - embedding: Union[str, List[float]], + embedding: List[float], index: int, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -694,16 +927,16 @@ def __init__( data: List["_models.EmbeddingItem"], usage: "_models.EmbeddingsUsage", model: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -731,16 +964,16 @@ def __init__( *, prompt_tokens: int, total_tokens: int, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -772,16 +1005,16 @@ def __init__( *, name: str, arguments: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -817,16 +1050,16 @@ def __init__( name: str, description: Optional[str] = None, parameters: Optional[Any] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -855,16 +1088,16 @@ def __init__( self, *, image_url: "_models.ImageUrl", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, type="image_url", **kwargs) @@ -894,16 +1127,16 @@ def __init__( *, url: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -938,16 +1171,16 @@ def __init__( model_name: str, model_type: Union[str, "_models.ModelType"], model_provider_name: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -964,7 +1197,7 @@ class StreamingChatChoiceUpdate(_model_base.Model): Required. Known values are: "stop", "length", "content_filter", and "tool_calls". :vartype finish_reason: str or ~azure.ai.inference.models.CompletionsFinishReason :ivar delta: An update to the chat message for a given chat completions prompt. Required. - :vartype delta: ~azure.ai.inference.models.StreamingChatResponseMessageUpdate + :vartype delta: ~azure.ai.inference.models.ChatResponseMessage """ index: int = rest_field() @@ -972,7 +1205,7 @@ class StreamingChatChoiceUpdate(_model_base.Model): finish_reason: Union[str, "_models.CompletionsFinishReason"] = rest_field() """The reason that this chat completions choice completed its generated. Required. Known values are: \"stop\", \"length\", \"content_filter\", and \"tool_calls\".""" - delta: "_models.StreamingChatResponseMessageUpdate" = rest_field() + delta: "_models.ChatResponseMessage" = rest_field() """An update to the chat message for a given chat completions prompt. Required.""" @overload @@ -981,17 +1214,17 @@ def __init__( *, index: int, finish_reason: Union[str, "_models.CompletionsFinishReason"], - delta: "_models.StreamingChatResponseMessageUpdate", - ): ... + delta: "_models.ChatResponseMessage", + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1046,94 +1279,16 @@ def __init__( model: str, usage: "_models.CompletionsUsage", choices: List["_models.StreamingChatChoiceUpdate"], - ): ... - - @overload - def __init__(self, mapping: Mapping[str, Any]): - """ - :param mapping: raw JSON to initialize the model. - :type mapping: Mapping[str, Any] - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation - super().__init__(*args, **kwargs) - - -class StreamingChatResponseMessageUpdate(_model_base.Model): - """A representation of a chat message update as received in a streaming response. - - :ivar role: The chat role associated with the message. If present, should always be - 'assistant'. Known values are: "system", "user", "assistant", and "tool". - :vartype role: str or ~azure.ai.inference.models.ChatRole - :ivar content: The content of the message. - :vartype content: str - :ivar tool_calls: The tool calls that must be resolved and have their outputs appended to - subsequent input messages for the chat - completions request to resolve as configured. - :vartype tool_calls: list[~azure.ai.inference.models.StreamingChatResponseToolCallUpdate] - """ - - role: Optional[Union[str, "_models.ChatRole"]] = rest_field() - """The chat role associated with the message. If present, should always be 'assistant'. Known - values are: \"system\", \"user\", \"assistant\", and \"tool\".""" - content: Optional[str] = rest_field() - """The content of the message.""" - tool_calls: Optional[List["_models.StreamingChatResponseToolCallUpdate"]] = rest_field() - """The tool calls that must be resolved and have their outputs appended to subsequent input - messages for the chat - completions request to resolve as configured.""" - - @overload - def __init__( - self, - *, - role: Optional[Union[str, "_models.ChatRole"]] = None, - content: Optional[str] = None, - tool_calls: Optional[List["_models.StreamingChatResponseToolCallUpdate"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation - super().__init__(*args, **kwargs) - - -class StreamingChatResponseToolCallUpdate(_model_base.Model): - """An update to the function tool call information requested by the AI model. - - - :ivar id: The ID of the tool call. Required. - :vartype id: str - :ivar function: Updates to the function call requested by the AI model. Required. - :vartype function: ~azure.ai.inference.models.FunctionCall - """ - - id: str = rest_field() - """The ID of the tool call. Required.""" - function: "_models.FunctionCall" = rest_field() - """Updates to the function call requested by the AI model. Required.""" - - @overload - def __init__( - self, - *, - id: str, # pylint: disable=redefined-builtin - function: "_models.FunctionCall", - ): ... - - @overload - def __init__(self, mapping: Mapping[str, Any]): - """ - :param mapping: raw JSON to initialize the model. - :type mapping: Mapping[str, Any] - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1162,16 +1317,16 @@ def __init__( self, *, content: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.SYSTEM, **kwargs) @@ -1198,16 +1353,16 @@ def __init__( self, *, text: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, type="text", **kwargs) @@ -1240,16 +1395,16 @@ def __init__( *, content: str, tool_call_id: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.TOOL, **kwargs) @@ -1269,7 +1424,7 @@ class UserMessage(ChatRequestMessage, discriminator="user"): role: Literal[ChatRole.USER] = rest_discriminator(name="role") # type: ignore """The chat role associated with this message, which is always 'user' for user messages. Required. The role that provides input for chat completions.""" - content: Union["str", List["_models.ContentItem"]] = rest_field() + content: Union[str, List["_models.ContentItem"]] = rest_field() """The contents of the user message, with available input types varying by selected model. Required. Is either a str type or a [ContentItem] type.""" @@ -1278,14 +1433,14 @@ def __init__( self, *, content: Union[str, List["_models.ContentItem"]], - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.USER, **kwargs) diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py index 61c718eea63f..f7dd32510333 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py @@ -6,273 +6,9 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -import asyncio -import base64 -import json -import logging -import queue -import re -import sys +from typing import List -from typing import Any, List, AsyncIterator, Iterator, Optional, Union -from azure.core.rest import HttpResponse, AsyncHttpResponse -from ._models import ImageUrl as ImageUrlGenerated -from ._models import ChatCompletions as ChatCompletionsGenerated -from ._models import EmbeddingsResult as EmbeddingsResultGenerated -from .. import models as _models - -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self - -logger = logging.getLogger(__name__) - - -class ChatCompletions(ChatCompletionsGenerated): - """Representation of the response data from a chat completions request. - Completions support a wide variety of tasks and generate text that continues from or - "completes" - provided prompt data. - - - :ivar id: A unique identifier associated with this chat completions response. Required. - :vartype id: str - :ivar created: The first timestamp associated with generation activity for this completions - response, - represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required. - :vartype created: ~datetime.datetime - :ivar model: The model used for the chat completion. Required. - :vartype model: str - :ivar usage: Usage information for tokens processed and generated as part of this completions - operation. Required. - :vartype usage: ~azure.ai.inference.models.CompletionsUsage - :ivar choices: The collection of completions choices associated with this completions response. - Generally, ``n`` choices are generated per provided prompt with a default value of 1. - Token limits and other settings may limit the number of choices generated. Required. - :vartype choices: list[~azure.ai.inference.models.ChatChoice] - """ - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return json.dumps(self.as_dict(), indent=2) - - -class EmbeddingsResult(EmbeddingsResultGenerated): - """Representation of the response data from an embeddings request. - Embeddings measure the relatedness of text strings and are commonly used for search, - clustering, - recommendations, and other similar scenarios. - - - :ivar data: Embedding values for the prompts submitted in the request. Required. - :vartype data: list[~azure.ai.inference.models.EmbeddingItem] - :ivar usage: Usage counts for tokens input using the embeddings API. Required. - :vartype usage: ~azure.ai.inference.models.EmbeddingsUsage - :ivar model: The model ID used to generate this result. Required. - :vartype model: str - """ - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return json.dumps(self.as_dict(), indent=2) - - -class ImageUrl(ImageUrlGenerated): - - @classmethod - def load( - cls, *, image_file: str, image_format: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None - ) -> Self: - """ - Create an ImageUrl object from a local image file. The method reads the image - file and encodes it as a base64 string, which together with the image format - is then used to format the JSON `url` value passed in the request payload. - - :ivar image_file: The name of the local image file to load. Required. - :vartype image_file: str - :ivar image_format: The MIME type format of the image. For example: "jpeg", "png". Required. - :vartype image_format: str - :ivar detail: The evaluation quality setting to use, which controls relative prioritization of - speed, token consumption, and accuracy. Known values are: "auto", "low", and "high". - :vartype detail: str or ~azure.ai.inference.models.ImageDetailLevel - :return: An ImageUrl object with the image data encoded as a base64 string. - :rtype: ~azure.ai.inference.models.ImageUrl - :raises FileNotFoundError: when the image file could not be opened. - """ - with open(image_file, "rb") as f: - image_data = base64.b64encode(f.read()).decode("utf-8") - url = f"data:image/{image_format};base64,{image_data}" - return cls(url=url, detail=detail) - - -class BaseStreamingChatCompletions: - """A base class for the sync and async streaming chat completions responses, holding any common code - to deserializes the Server Sent Events (SSE) response stream into chat completions updates, each one - represented by a StreamingChatCompletionsUpdate object. - """ - - # Enable detailed logs of SSE parsing. For development only, should be `False` by default. - _ENABLE_CLASS_LOGS = False - - # The prefix of each line in the SSE stream that contains a JSON string - # to deserialize into a StreamingChatCompletionsUpdate object - _SSE_DATA_EVENT_PREFIX = "data: " - - # The line indicating the end of the SSE stream - _SSE_DATA_EVENT_DONE = "data: [DONE]" - - def __init__(self): - self._queue: "queue.Queue[_models.StreamingChatCompletionsUpdate]" = queue.Queue() - self._incomplete_json = "" - self._done = False # Will be set to True when reading 'data: [DONE]' line - - def _deserialize_and_add_to_queue(self, element: bytes) -> bool: - - # Clear the queue of StreamingChatCompletionsUpdate before processing the next block - self._queue.queue.clear() - - # Convert `bytes` to string and split the string by newline, while keeping the new line char. - # the last may be a partial "line" that does not contain a newline char at the end. - line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8")) - for index, line in enumerate(line_list): - - if self._ENABLE_CLASS_LOGS: - logger.debug("[Original line] %s", repr(line)) - - if index == 0: - line = self._incomplete_json + line - self._incomplete_json = "" - - if index == len(line_list) - 1 and not line.endswith("\n"): - self._incomplete_json = line - return False - - if self._ENABLE_CLASS_LOGS: - logger.debug("[Modified line] %s", repr(line)) - - if line == "\n": # Empty line, indicating flush output to client - continue - - if not line.startswith(self._SSE_DATA_EVENT_PREFIX): - raise ValueError(f"SSE event not supported (line `{line}`)") - - if line.startswith(self._SSE_DATA_EVENT_DONE): - if self._ENABLE_CLASS_LOGS: - logger.debug("[Done]") - return True - - # If you reached here, the line should contain `data: {...}\n` - # where the curly braces contain a valid JSON object. - # Deserialize it into a StreamingChatCompletionsUpdate object - # and add it to the queue. - # pylint: disable=W0212 # Access to a protected member _deserialize of a client class - update = _models.StreamingChatCompletionsUpdate._deserialize( - json.loads(line[len(self._SSE_DATA_EVENT_PREFIX) : -1]), [] - ) - - # We skip any update that has a None or empty choices list - # (this is what OpenAI Python SDK does) - if update.choices: - - # We update all empty content strings to None - # (this is what OpenAI Python SDK does) - # for choice in update.choices: - # if not choice.delta.content: - # choice.delta.content = None - - self._queue.put(update) - - if self._ENABLE_CLASS_LOGS: - logger.debug("[Added to queue]") - - return False - - -class StreamingChatCompletions(BaseStreamingChatCompletions): - """Represents an interator over StreamingChatCompletionsUpdate objects. It can be used for either synchronous or - asynchronous iterations. The class deserializes the Server Sent Events (SSE) response stream - into chat completions updates, each one represented by a StreamingChatCompletionsUpdate object. - """ - - def __init__(self, response: HttpResponse): - super().__init__() - self._response = response - self._bytes_iterator: Iterator[bytes] = response.iter_bytes() - - def __iter__(self) -> Any: - return self - - def __next__(self) -> "_models.StreamingChatCompletionsUpdate": - while self._queue.empty() and not self._done: - self._done = self._read_next_block() - if self._queue.empty(): - raise StopIteration - return self._queue.get() - - def _read_next_block(self) -> bool: - if self._ENABLE_CLASS_LOGS: - logger.debug("[Reading next block]") - try: - element = self._bytes_iterator.__next__() - except StopIteration: - self.close() - return True - return self._deserialize_and_add_to_queue(element) - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # type: ignore - self.close() - - def close(self) -> None: - self._response.close() - - -class AsyncStreamingChatCompletions(BaseStreamingChatCompletions): - """Represents an async interator over StreamingChatCompletionsUpdate objects. - It can be used for either synchronous or asynchronous iterations. The class - deserializes the Server Sent Events (SSE) response stream into chat - completions updates, each one represented by a StreamingChatCompletionsUpdate object. - """ - - def __init__(self, response: AsyncHttpResponse): - super().__init__() - self._response = response - self._bytes_iterator: AsyncIterator[bytes] = response.iter_bytes() - - def __aiter__(self) -> Any: - return self - - async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate": - while self._queue.empty() and not self._done: - self._done = await self._read_next_block_async() - if self._queue.empty(): - raise StopAsyncIteration - return self._queue.get() - - async def _read_next_block_async(self) -> bool: - if self._ENABLE_CLASS_LOGS: - logger.debug("[Reading next block]") - try: - element = await self._bytes_iterator.__anext__() - except StopAsyncIteration: - await self.aclose() - return True - return self._deserialize_and_add_to_queue(element) - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # type: ignore - asyncio.run(self.aclose()) - - async def aclose(self) -> None: - await self._response.close() - - -__all__: List[str] = [ - "ImageUrl", - "ChatCompletions", - "EmbeddingsResult", - "StreamingChatCompletions", - "AsyncStreamingChatCompletions", -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py b/sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py deleted file mode 100644 index b2690c13cddd..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py +++ /dev/null @@ -1,676 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -import copy -from enum import Enum -import functools -import json -import importlib -import logging -import os -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union -from urllib.parse import urlparse - -# pylint: disable = no-name-in-module -from azure.core import CaseInsensitiveEnumMeta # type: ignore -from azure.core.settings import settings -from . import models as _models - -try: - # pylint: disable = no-name-in-module - from azure.core.tracing import AbstractSpan, SpanKind # type: ignore - from opentelemetry.trace import StatusCode, Span - - _tracing_library_available = True -except ModuleNotFoundError: - - _tracing_library_available = False - - -__all__ = [ - "AIInferenceInstrumentor", -] - - -_inference_traces_enabled: bool = False -_trace_inference_content: bool = False -_INFERENCE_GEN_AI_SYSTEM_NAME = "az.ai.inference" - - -class TraceType(str, Enum, metaclass=CaseInsensitiveEnumMeta): # pylint: disable=C4747 - """An enumeration class to represent different types of traces.""" - - INFERENCE = "Inference" - - -class AIInferenceInstrumentor: - """ - A class for managing the trace instrumentation of AI Inference. - - This class allows enabling or disabling tracing for AI Inference. - and provides functionality to check whether instrumentation is active. - - """ - - def __init__(self): - if not _tracing_library_available: - raise ModuleNotFoundError( - "Azure Core Tracing Opentelemetry is not installed. " - "Please install it using 'pip install azure-core-tracing-opentelemetry'" - ) - # In the future we could support different versions from the same library - # and have a parameter that specifies the version to use. - self._impl = _AIInferenceInstrumentorPreview() - - def instrument(self) -> None: - """ - Enable trace instrumentation for AI Inference. - - Raises: - RuntimeError: If instrumentation is already enabled. - - """ - self._impl.instrument() - - def uninstrument(self) -> None: - """ - Disable trace instrumentation for AI Inference. - - Raises: - RuntimeError: If instrumentation is not currently enabled. - - This method removes any active instrumentation, stopping the tracing - of AI Inference. - """ - self._impl.uninstrument() - - def is_instrumented(self) -> bool: - """ - Check if trace instrumentation for AI Inference is currently enabled. - - :return: True if instrumentation is active, False otherwise. - :rtype: bool - """ - return self._impl.is_instrumented() - - -class _AIInferenceInstrumentorPreview: - """ - A class for managing the trace instrumentation of AI Inference. - - This class allows enabling or disabling tracing for AI Inference. - and provides functionality to check whether instrumentation is active. - """ - - def _str_to_bool(self, s): - if s is None: - return False - return str(s).lower() == "true" - - def instrument(self): - """ - Enable trace instrumentation for AI Inference. - - Raises: - RuntimeError: If instrumentation is already enabled. - - This method checks the environment variable - 'AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED' to determine - whether to enable content tracing. - """ - if self.is_instrumented(): - raise RuntimeError("Already instrumented") - - var_value = os.environ.get("AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED") - enable_content_tracing = self._str_to_bool(var_value) - self._instrument_inference(enable_content_tracing) - - def uninstrument(self): - """ - Disable trace instrumentation for AI Inference. - - Raises: - RuntimeError: If instrumentation is not currently enabled. - - This method removes any active instrumentation, stopping the tracing - of AI Inference. - """ - if not self.is_instrumented(): - raise RuntimeError("Not instrumented") - self._uninstrument_inference() - - def is_instrumented(self): - """ - Check if trace instrumentation for AI Inference is currently enabled. - - :return: True if instrumentation is active, False otherwise. - :rtype: bool - """ - return self._is_instrumented() - - def _set_attributes(self, span: "AbstractSpan", *attrs: Tuple[str, Any]) -> None: - for attr in attrs: - key, value = attr - if value is not None: - span.add_attribute(key, value) - - def _add_request_chat_message_event(self, span: "AbstractSpan", **kwargs: Any) -> None: - for message in kwargs.get("messages", []): - try: - message = message.as_dict() - except AttributeError: - pass - - if message.get("role"): - name = f"gen_ai.{message.get('role')}.message" - span.span_instance.add_event( - name=name, - attributes={ - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(message), - }, - ) - - def _parse_url(self, url): - parsed = urlparse(url) - server_address = parsed.hostname - port = parsed.port - return server_address, port - - def _add_request_chat_attributes(self, span: "AbstractSpan", *args: Any, **kwargs: Any) -> None: - client = args[0] - endpoint = client._config.endpoint # pylint: disable=protected-access - server_address, port = self._parse_url(endpoint) - model = "chat" - if kwargs.get("model") is not None: - model_value = kwargs.get("model") - if model_value is not None: - model = model_value - - self._set_attributes( - span, - ("gen_ai.operation.name", "chat"), - ("gen_ai.system", _INFERENCE_GEN_AI_SYSTEM_NAME), - ("gen_ai.request.model", model), - ("gen_ai.request.max_tokens", kwargs.get("max_tokens")), - ("gen_ai.request.temperature", kwargs.get("temperature")), - ("gen_ai.request.top_p", kwargs.get("top_p")), - ("server.address", server_address), - ) - if port is not None and port != 443: - span.add_attribute("server.port", port) - - def _remove_function_call_names_and_arguments(self, tool_calls: list) -> list: - tool_calls_copy = copy.deepcopy(tool_calls) - for tool_call in tool_calls_copy: - if "function" in tool_call: - if "name" in tool_call["function"]: - del tool_call["function"]["name"] - if "arguments" in tool_call["function"]: - del tool_call["function"]["arguments"] - if not tool_call["function"]: - del tool_call["function"] - return tool_calls_copy - - def _get_finish_reasons(self, result) -> Optional[List[str]]: - if hasattr(result, "choices") and result.choices: - finish_reasons: List[str] = [] - for choice in result.choices: - finish_reason = getattr(choice, "finish_reason", None) - - if finish_reason is None: - # If finish_reason is None, default to "none" - finish_reasons.append("none") - elif hasattr(finish_reason, "value"): - # If finish_reason has a 'value' attribute (i.e., it's an enum), use it - finish_reasons.append(finish_reason.value) - elif isinstance(finish_reason, str): - # If finish_reason is a string, use it directly - finish_reasons.append(finish_reason) - else: - # Default to "none" - finish_reasons.append("none") - - return finish_reasons - return None - - def _get_finish_reason_for_choice(self, choice): - finish_reason = getattr(choice, "finish_reason", None) - if finish_reason is not None: - return finish_reason.value - - return "none" - - def _add_response_chat_message_event(self, span: "AbstractSpan", result: _models.ChatCompletions) -> None: - for choice in result.choices: - if _trace_inference_content: - full_response: Dict[str, Any] = { - "message": {"content": choice.message.content}, - "finish_reason": self._get_finish_reason_for_choice(choice), - "index": choice.index, - } - if choice.message.tool_calls: - full_response["message"]["tool_calls"] = [tool.as_dict() for tool in choice.message.tool_calls] - attributes = { - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(full_response), - } - else: - response: Dict[str, Any] = { - "finish_reason": self._get_finish_reason_for_choice(choice), - "index": choice.index, - } - if choice.message.tool_calls: - response["message"] = {} - tool_calls_function_names_and_arguments_removed = self._remove_function_call_names_and_arguments( - choice.message.tool_calls - ) - response["message"]["tool_calls"] = [ - tool.as_dict() for tool in tool_calls_function_names_and_arguments_removed - ] - - attributes = { - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(response), - } - span.span_instance.add_event(name="gen_ai.choice", attributes=attributes) - - def _add_response_chat_attributes( - self, - span: "AbstractSpan", - result: Union[_models.ChatCompletions, _models.StreamingChatCompletionsUpdate], - ) -> None: - self._set_attributes( - span, - ("gen_ai.response.id", result.id), - ("gen_ai.response.model", result.model), - ( - "gen_ai.usage.input_tokens", - (result.usage.prompt_tokens if hasattr(result, "usage") and result.usage else None), - ), - ( - "gen_ai.usage.output_tokens", - (result.usage.completion_tokens if hasattr(result, "usage") and result.usage else None), - ), - ) - finish_reasons = self._get_finish_reasons(result) - if not finish_reasons is None: - span.add_attribute("gen_ai.response.finish_reasons", finish_reasons) # type: ignore - - def _add_request_span_attributes(self, span: "AbstractSpan", _span_name: str, args: Any, kwargs: Any) -> None: - self._add_request_chat_attributes(span, *args, **kwargs) - if _trace_inference_content: - self._add_request_chat_message_event(span, **kwargs) - - def _add_response_span_attributes(self, span: "AbstractSpan", result: object) -> None: - if isinstance(result, _models.ChatCompletions): - self._add_response_chat_attributes(span, result) - self._add_response_chat_message_event(span, result) - # TODO add more models here - - def _accumulate_response(self, item, accumulate: Dict[str, Any]) -> None: - if item.finish_reason: - accumulate["finish_reason"] = item.finish_reason - if item.index: - accumulate["index"] = item.index - if item.delta.content: - accumulate.setdefault("message", {}) - accumulate["message"].setdefault("content", "") - accumulate["message"]["content"] += item.delta.content - if item.delta.tool_calls: - accumulate.setdefault("message", {}) - accumulate["message"].setdefault("tool_calls", []) - if item.delta.tool_calls is not None: - for tool_call in item.delta.tool_calls: - if tool_call.id: - accumulate["message"]["tool_calls"].append( - { - "id": tool_call.id, - "type": "", - "function": {"name": "", "arguments": ""}, - } - ) - if tool_call.function: - accumulate["message"]["tool_calls"][-1]["type"] = "function" - if tool_call.function and tool_call.function.name: - accumulate["message"]["tool_calls"][-1]["function"]["name"] = tool_call.function.name - if tool_call.function and tool_call.function.arguments: - accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments - - def _wrapped_stream( - self, stream_obj: _models.StreamingChatCompletions, span: "AbstractSpan" - ) -> _models.StreamingChatCompletions: - class StreamWrapper(_models.StreamingChatCompletions): - def __init__(self, stream_obj, instrumentor): - super().__init__(stream_obj._response) - self._instrumentor = instrumentor - - def __iter__( # pyright: ignore [reportIncompatibleMethodOverride] - self, - ) -> Iterator[_models.StreamingChatCompletionsUpdate]: - accumulate: Dict[str, Any] = {} - try: - chunk = None - for chunk in stream_obj: - for item in chunk.choices: - self._instrumentor._accumulate_response(item, accumulate) - yield chunk - - if chunk is not None: - self._instrumentor._add_response_chat_attributes(span, chunk) - - except Exception as exc: - # Set the span status to error - if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] - span.span_instance.set_status( - StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] - description=str(exc), - ) - module = exc.__module__ if hasattr(exc, "__module__") and exc.__module__ != "builtins" else "" - error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ - self._instrumentor._set_attributes(span, ("error.type", error_type)) - raise - - finally: - if stream_obj._done is False: - if accumulate.get("finish_reason") is None: - accumulate["finish_reason"] = "error" - else: - # Only one choice expected with streaming - accumulate["index"] = 0 - # Delete message if content tracing is not enabled - if not _trace_inference_content: - if "message" in accumulate: - if "content" in accumulate["message"]: - del accumulate["message"]["content"] - if not accumulate["message"]: - del accumulate["message"] - if "message" in accumulate: - if "tool_calls" in accumulate["message"]: - tool_calls_function_names_and_arguments_removed = ( - self._instrumentor._remove_function_call_names_and_arguments( - accumulate["message"]["tool_calls"] - ) - ) - accumulate["message"]["tool_calls"] = list( - tool_calls_function_names_and_arguments_removed - ) - - span.span_instance.add_event( - name="gen_ai.choice", - attributes={ - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(accumulate), - }, - ) - span.finish() - - return StreamWrapper(stream_obj, self) - - def _trace_sync_function( - self, - function: Callable, - *, - _args_to_ignore: Optional[List[str]] = None, - _trace_type=TraceType.INFERENCE, - _name: Optional[str] = None, - ) -> Callable: - """ - Decorator that adds tracing to a synchronous function. - - :param function: The function to be traced. - :type function: Callable - :param args_to_ignore: A list of argument names to be ignored in the trace. - Defaults to None. - :type: args_to_ignore: [List[str]], optional - :param trace_type: The type of the trace. Defaults to TraceType.INFERENCE. - :type trace_type: TraceType, optional - :param name: The name of the trace, will set to func name if not provided. - :type name: str, optional - :return: The traced function. - :rtype: Callable - """ - - @functools.wraps(function) - def inner(*args, **kwargs): - - span_impl_type = settings.tracing_implementation() - if span_impl_type is None: - return function(*args, **kwargs) - - class_function_name = function.__qualname__ - - if class_function_name.startswith("ChatCompletionsClient.complete"): - if kwargs.get("model") is None: - span_name = "chat" - else: - model = kwargs.get("model") - span_name = f"chat {model}" - - span = span_impl_type( - name=span_name, - kind=SpanKind.CLIENT, # pyright: ignore [reportPossiblyUnboundVariable] - ) - try: - # tracing events not supported in azure-core-tracing-opentelemetry - # so need to access the span instance directly - with span_impl_type.change_context(span.span_instance): - self._add_request_span_attributes(span, span_name, args, kwargs) - result = function(*args, **kwargs) - if kwargs.get("stream") is True: - return self._wrapped_stream(result, span) - self._add_response_span_attributes(span, result) - - except Exception as exc: - # Set the span status to error - if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] - span.span_instance.set_status( - StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] - description=str(exc), - ) - module = getattr(exc, "__module__", "") - module = module if module != "builtins" else "" - error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ - self._set_attributes(span, ("error.type", error_type)) - span.finish() - raise - - span.finish() - return result - - # Handle the default case (if the function name does not match) - return None # Ensure all paths return - - return inner - - def _trace_async_function( - self, - function: Callable, - *, - _args_to_ignore: Optional[List[str]] = None, - _trace_type=TraceType.INFERENCE, - _name: Optional[str] = None, - ) -> Callable: - """ - Decorator that adds tracing to an asynchronous function. - - :param function: The function to be traced. - :type function: Callable - :param args_to_ignore: A list of argument names to be ignored in the trace. - Defaults to None. - :type: args_to_ignore: [List[str]], optional - :param trace_type: The type of the trace. Defaults to TraceType.INFERENCE. - :type trace_type: TraceType, optional - :param name: The name of the trace, will set to func name if not provided. - :type name: str, optional - :return: The traced function. - :rtype: Callable - """ - - @functools.wraps(function) - async def inner(*args, **kwargs): - span_impl_type = settings.tracing_implementation() - if span_impl_type is None: - return function(*args, **kwargs) - - class_function_name = function.__qualname__ - - if class_function_name.startswith("ChatCompletionsClient.complete"): - if kwargs.get("model") is None: - span_name = "chat" - else: - model = kwargs.get("model") - span_name = f"chat {model}" - - span = span_impl_type( - name=span_name, - kind=SpanKind.CLIENT, # pyright: ignore [reportPossiblyUnboundVariable] - ) - try: - # tracing events not supported in azure-core-tracing-opentelemetry - # so need to access the span instance directly - with span_impl_type.change_context(span.span_instance): - self._add_request_span_attributes(span, span_name, args, kwargs) - result = await function(*args, **kwargs) - if kwargs.get("stream") is True: - return self._wrapped_stream(result, span) - self._add_response_span_attributes(span, result) - - except Exception as exc: - # Set the span status to error - if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] - span.span_instance.set_status( - StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] - description=str(exc), - ) - module = getattr(exc, "__module__", "") - module = module if module != "builtins" else "" - error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ - self._set_attributes(span, ("error.type", error_type)) - span.finish() - raise - - span.finish() - return result - - # Handle the default case (if the function name does not match) - return None # Ensure all paths return - - return inner - - def _inject_async(self, f, _trace_type, _name): - wrapper_fun = self._trace_async_function(f) - wrapper_fun._original = f # pylint: disable=protected-access # pyright: ignore [reportFunctionMemberAccess] - return wrapper_fun - - def _inject_sync(self, f, _trace_type, _name): - wrapper_fun = self._trace_sync_function(f) - wrapper_fun._original = f # pylint: disable=protected-access # pyright: ignore [reportFunctionMemberAccess] - return wrapper_fun - - def _inference_apis(self): - sync_apis = ( - ( - "azure.ai.inference", - "ChatCompletionsClient", - "complete", - TraceType.INFERENCE, - "inference_chat_completions_complete", - ), - ) - async_apis = ( - ( - "azure.ai.inference.aio", - "ChatCompletionsClient", - "complete", - TraceType.INFERENCE, - "inference_chat_completions_complete", - ), - ) - return sync_apis, async_apis - - def _inference_api_list(self): - sync_apis, async_apis = self._inference_apis() - yield sync_apis, self._inject_sync - yield async_apis, self._inject_async - - def _generate_api_and_injector(self, apis): - for api, injector in apis: - for module_name, class_name, method_name, trace_type, name in api: - try: - module = importlib.import_module(module_name) - api = getattr(module, class_name) - if hasattr(api, method_name): - yield api, method_name, trace_type, injector, name - except AttributeError as e: - # Log the attribute exception with the missing class information - logging.warning( - "AttributeError: The module '%s' does not have the class '%s'. %s", - module_name, - class_name, - str(e), - ) - except Exception as e: # pylint: disable=broad-except - # Log other exceptions as a warning, as we're not sure what they might be - logging.warning("An unexpected error occurred: '%s'", str(e)) - - def _available_inference_apis_and_injectors(self): - """ - Generates a sequence of tuples containing Inference API classes, method names, and - corresponding injector functions. - - :return: A generator yielding tuples. - :rtype: tuple - """ - yield from self._generate_api_and_injector(self._inference_api_list()) - - def _instrument_inference(self, enable_content_tracing: bool = False): - """This function modifies the methods of the Inference API classes to - inject logic before calling the original methods. - The original methods are stored as _original attributes of the methods. - - :param enable_content_tracing: Indicates whether tracing of message content should be enabled. - This also controls whether function call tool function names, - parameter names and parameter values are traced. - :type enable_content_tracing: bool - """ - # pylint: disable=W0603 - global _inference_traces_enabled - global _trace_inference_content - if _inference_traces_enabled: - raise RuntimeError("Traces already started for azure.ai.inference") - _inference_traces_enabled = True - _trace_inference_content = enable_content_tracing - for ( - api, - method, - trace_type, - injector, - name, - ) in self._available_inference_apis_and_injectors(): - # Check if the method of the api class has already been modified - if not hasattr(getattr(api, method), "_original"): - setattr(api, method, injector(getattr(api, method), trace_type, name)) - - def _uninstrument_inference(self): - """This function restores the original methods of the Inference API classes - by assigning them back from the _original attributes of the modified methods. - """ - # pylint: disable=W0603 - global _inference_traces_enabled - global _trace_inference_content - _trace_inference_content = False - for api, method, _, _, _ in self._available_inference_apis_and_injectors(): - if hasattr(getattr(api, method), "_original"): - setattr(api, method, getattr(getattr(api, method), "_original")) - _inference_traces_enabled = False - - def _is_instrumented(self): - """This function returns True if Inference libary has already been instrumented - for tracing and False if it has not been instrumented. - - :return: A value indicating whether the Inference library is currently instrumented or not. - :rtype: bool - """ - return _inference_traces_enabled diff --git a/sdk/ai/azure-ai-inference/samples/async_samples/sample_chat_completions_from_input_json_async.py b/sdk/ai/azure-ai-inference/samples/async_samples/sample_chat_completions_from_input_json_async.py index ec2dd6afae75..25d6ce20cce7 100644 --- a/sdk/ai/azure-ai-inference/samples/async_samples/sample_chat_completions_from_input_json_async.py +++ b/sdk/ai/azure-ai-inference/samples/async_samples/sample_chat_completions_from_input_json_async.py @@ -58,10 +58,7 @@ async def sample_chat_completions_from_input_json_async(): "role": "assistant", "content": "The main construction of the International Space Station (ISS) was completed between 1998 and 2011. During this period, more than 30 flights by US space shuttles and 40 by Russian rockets were conducted to transport components and modules to the station.", }, - { - "role": "user", - "content": "And what was the estimated cost to build it?" - }, + {"role": "user", "content": "And what was the estimated cost to build it?"}, ] } diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_azure_openai.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_azure_openai.py index f025eea212cb..e4b03dbe50f9 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_azure_openai.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_azure_openai.py @@ -65,7 +65,7 @@ def sample_chat_completions_azure_openai(): endpoint=endpoint, credential=DefaultAzureCredential(exclude_interactive_browser_credential=False), credential_scopes=["https://cognitiveservices.azure.com/.default"], - api_version="2024-06-01", # Azure OpenAI api-version. See https://aka.ms/azsdk/azure-ai-inference/azure-openai-api-versions + api_version="2024-06-01", # Azure OpenAI api-version. See https://aka.ms/azsdk/azure-ai-inference/azure-openai-api-versions ) response = client.complete( diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json.py index 925583af4772..78a9b9a42690 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json.py @@ -58,10 +58,7 @@ def sample_chat_completions_from_input_json(): "role": "assistant", "content": "The main construction of the International Space Station (ISS) was completed between 1998 and 2011. During this period, more than 30 flights by US space shuttles and 40 by Russian rockets were conducted to transport components and modules to the station.", }, - { - "role": "user", - "content": "And what was the estimated cost to build it?" - }, + {"role": "user", "content": "And what was the estimated cost to build it?"}, ] } ) diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json_with_image_url.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json_with_image_url.py index 912b98afccb8..83f3afceaa19 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json_with_image_url.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_from_input_json_with_image_url.py @@ -54,9 +54,7 @@ def sample_chat_completions_from_input_json_with_image_url(): model_deployment = None client = ChatCompletionsClient( - endpoint=endpoint, - credential=AzureKeyCredential(key), - headers={"azureml-model-deployment": model_deployment} + endpoint=endpoint, credential=AzureKeyCredential(key), headers={"azureml-model-deployment": model_deployment} ) response = client.complete( @@ -69,10 +67,7 @@ def sample_chat_completions_from_input_json_with_image_url(): { "role": "user", "content": [ - { - "type": "text", - "text": "What's in this image?" - }, + {"type": "text", "text": "What's in this image?"}, { "type": "image_url", "image_url": { diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_streaming_with_tools.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_streaming_with_tools.py index dfa62afa2127..8eb5c7472af4 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_streaming_with_tools.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_streaming_with_tools.py @@ -35,6 +35,7 @@ use_azure_openai_endpoint = True + def sample_chat_completions_streaming_with_tools(): import os import json @@ -79,11 +80,9 @@ def get_flight_info(origin_city: str, destination_city: str): str: The airline name, fight number, date and time of the next flight between the cities, in JSON format. """ if origin_city == "Seattle" and destination_city == "Miami": - return json.dumps({ - "airline": "Delta", - "flight_number": "DL123", - "flight_date": "May 7th, 2024", - "flight_time": "10:00AM"}) + return json.dumps( + {"airline": "Delta", "flight_number": "DL123", "flight_date": "May 7th, 2024", "flight_time": "10:00AM"} + ) return json.dumps({"error": "No flights found between the cities"}) # Define a function 'tool' that the model can use to retrieves flight information @@ -117,10 +116,7 @@ def get_flight_info(origin_city: str, destination_city: str): ) else: # Create a chat completions client for Serverless API endpoint or Managed Compute endpoint - client = ChatCompletionsClient( - endpoint=endpoint, - credential=AzureKeyCredential(key) - ) + client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key)) # Make a streaming chat completions call asking for flight information, while providing a tool to handle the request messages = [ @@ -128,10 +124,7 @@ def get_flight_info(origin_city: str, destination_city: str): UserMessage(content="What is the next flights from Seattle to Miami?"), ] - response = client.complete( - messages=messages, - tools=[flight_info], - stream=True) + response = client.complete(messages=messages, tools=[flight_info], stream=True) # Note that in the above call we did not specify `tool_choice`. The service defaults to a setting equivalent # to specifying `tool_choice=ChatCompletionsToolChoicePreset.AUTO`. Other than ChatCompletionsToolChoicePreset @@ -158,11 +151,7 @@ def get_flight_info(origin_city: str, destination_city: str): AssistantMessage( tool_calls=[ ChatCompletionsToolCall( - id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=function_args - ) + id=tool_call_id, function=FunctionCall(name=function_name, arguments=function_args) ) ] ) @@ -176,19 +165,10 @@ def get_flight_info(origin_city: str, destination_city: str): print(f"Function response = {function_response}") # Append the function response as a tool message to the chat history - messages.append( - ToolMessage( - tool_call_id=tool_call_id, - content=function_response - ) - ) + messages.append(ToolMessage(tool_call_id=tool_call_id, content=function_response)) # With the additional tools information on hand, get another streaming response from the model - response = client.complete( - messages=messages, - tools=[flight_info], - stream=True - ) + response = client.complete(messages=messages, tools=[flight_info], stream=True) print("Model response = ", end="") for update in response: diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_defaults.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_defaults.py index 011735a7e61f..269ce2d232de 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_defaults.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_defaults.py @@ -43,10 +43,7 @@ def sample_chat_completions_with_defaults(): # Create a client with default chat completions settings client = ChatCompletionsClient( - endpoint=endpoint, - credential=AzureKeyCredential(key), - temperature=0.5, - max_tokens=1000 + endpoint=endpoint, credential=AzureKeyCredential(key), temperature=0.5, max_tokens=1000 ) # Call the service with the defaults specified above diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py index 3d14a550ab68..2074c447fdfe 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tools.py @@ -64,14 +64,11 @@ def get_flight_info(origin_city: str, destination_city: str): str: The airline name, fight number, date and time of the next flight between the cities, in JSON format. """ if origin_city == "Seattle" and destination_city == "Miami": - return json.dumps({ - "airline": "Delta", - "flight_number": "DL123", - "flight_date": "May 7th, 2024", - "flight_time": "10:00AM"}) + return json.dumps( + {"airline": "Delta", "flight_number": "DL123", "flight_date": "May 7th, 2024", "flight_time": "10:00AM"} + ) return json.dumps({"error": "No flights found between the cities"}) - # Define a function 'tool' that the model can use to retrieves flight information flight_info = ChatCompletionsToolDefinition( function=FunctionDefinition( @@ -95,10 +92,7 @@ def get_flight_info(origin_city: str, destination_city: str): ) # Create a chat completion client. Make sure you selected a model that supports tools. - client = ChatCompletionsClient( - endpoint=endpoint, - credential=AzureKeyCredential(key) - ) + client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key)) # Make a chat completions call asking for flight information, while providing a tool to handle the request messages = [ diff --git a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tracing.py b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tracing.py index 875010ebbd26..3c8120d3ff01 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tracing.py +++ b/sdk/ai/azure-ai-inference/samples/sample_chat_completions_with_tracing.py @@ -28,34 +28,38 @@ import os from opentelemetry import trace + # opentelemetry-sdk is required for the opentelemetry.sdk imports. # You can install it with command "pip install opentelemetry-sdk". -#from opentelemetry.sdk.trace import TracerProvider -#from opentelemetry.sdk.trace.export import SimpleSpanProcessor, ConsoleSpanExporter +# from opentelemetry.sdk.trace import TracerProvider +# from opentelemetry.sdk.trace.export import SimpleSpanProcessor, ConsoleSpanExporter from azure.ai.inference import ChatCompletionsClient from azure.ai.inference.models import SystemMessage, UserMessage, CompletionsFinishReason from azure.core.credentials import AzureKeyCredential - # [START trace_setting] +# [START trace_setting] from azure.core.settings import settings + settings.tracing_implementation = "opentelemetry" # [END trace_setting] # Setup tracing to console # Requires opentelemetry-sdk -#exporter = ConsoleSpanExporter() -#trace.set_tracer_provider(TracerProvider()) -#tracer = trace.get_tracer(__name__) -#trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) +# exporter = ConsoleSpanExporter() +# trace.set_tracer_provider(TracerProvider()) +# tracer = trace.get_tracer(__name__) +# trace.get_tracer_provider().add_span_processor(SimpleSpanProcessor(exporter)) - # [START trace_function] +# [START trace_function] from opentelemetry.trace import get_tracer + tracer = get_tracer(__name__) + # The tracer.start_as_current_span decorator will trace the function call and enable adding additional attributes # to the span in the function implementation. Note that this will trace the function parameters and their values. -@tracer.start_as_current_span("get_temperature") # type: ignore +@tracer.start_as_current_span("get_temperature") # type: ignore def get_temperature(city: str) -> str: # Adding attributes to the current span @@ -68,7 +72,9 @@ def get_temperature(city: str) -> str: return "80" else: return "Unavailable" - # [END trace_function] + + +# [END trace_function] def get_weather(city: str) -> str: @@ -82,7 +88,13 @@ def get_weather(city: str) -> str: def chat_completion_with_function_call(key, endpoint): import json - from azure.ai.inference.models import ToolMessage, AssistantMessage, ChatCompletionsToolCall, ChatCompletionsToolDefinition, FunctionDefinition + from azure.ai.inference.models import ( + ToolMessage, + AssistantMessage, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + FunctionDefinition, + ) weather_description = ChatCompletionsToolDefinition( function=FunctionDefinition( @@ -119,7 +131,7 @@ def chat_completion_with_function_call(key, endpoint): ) client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(key)) - messages=[ + messages = [ SystemMessage(content="You are a helpful assistant."), UserMessage(content="What is the weather and temperature in Seattle?"), ] @@ -142,13 +154,14 @@ def chat_completion_with_function_call(key, endpoint): messages.append(ToolMessage(tool_call_id=tool_call.id, content=function_response)) # With the additional tools information on hand, get another response from the model response = client.complete(messages=messages, tools=[weather_description, temperature_in_city]) - + print(f"Model response = {response.choices[0].message.content}") def main(): # [START instrument_inferencing] from azure.ai.inference.tracing import AIInferenceInstrumentor + # Instrument AI Inference API AIInferenceInstrumentor().instrument() # [END instrument_inferencing] diff --git a/sdk/ai/azure-ai-inference/samples/sample_embeddings_with_base64_encoding.py b/sdk/ai/azure-ai-inference/samples/sample_embeddings_with_base64_encoding.py index 9d9ec9c5c492..248bccb83a55 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_embeddings_with_base64_encoding.py +++ b/sdk/ai/azure-ai-inference/samples/sample_embeddings_with_base64_encoding.py @@ -44,13 +44,15 @@ def sample_embeddings_with_base64_encoding(): # Request embeddings as base64 encoded strings response = client.embed( - input=["first phrase", "second phrase", "third phrase"], - encoding_format=EmbeddingEncodingFormat.BASE64) + input=["first phrase", "second phrase", "third phrase"], encoding_format=EmbeddingEncodingFormat.BASE64 + ) for item in response.data: # Display the start and end of the resulting base64 string - print(f"data[{item.index}] encoded (string length={len(item.embedding)}): " - f"\"{item.embedding[:32]}...{item.embedding[-32:]}\"") + print( + f"data[{item.index}] encoded (string length={len(item.embedding)}): " + f'"{item.embedding[:32]}...{item.embedding[-32:]}"' + ) # For display purposes, decode the string into a list of floating point numbers. # Display the first and last two elements of the list. diff --git a/sdk/ai/azure-ai-inference/samples/sample_image_embeddings_with_defaults.py b/sdk/ai/azure-ai-inference/samples/sample_image_embeddings_with_defaults.py index 3ce84554ab4d..5282f22e4f45 100644 --- a/sdk/ai/azure-ai-inference/samples/sample_image_embeddings_with_defaults.py +++ b/sdk/ai/azure-ai-inference/samples/sample_image_embeddings_with_defaults.py @@ -49,10 +49,7 @@ def sample_image_embeddings_with_defaults(): # Create a client with default embeddings settings client = ImageEmbeddingsClient( - endpoint=endpoint, - credential=AzureKeyCredential(key), - dimensions=1024, - input_type=EmbeddingInputType.QUERY + endpoint=endpoint, credential=AzureKeyCredential(key), dimensions=1024, input_type=EmbeddingInputType.QUERY ) # Call the service with the defaults specified above diff --git a/sdk/ai/azure-ai-inference/sdk_packaging.toml b/sdk/ai/azure-ai-inference/sdk_packaging.toml new file mode 100644 index 000000000000..e7687fdae93b --- /dev/null +++ b/sdk/ai/azure-ai-inference/sdk_packaging.toml @@ -0,0 +1,2 @@ +[packaging] +auto_update = false \ No newline at end of file diff --git a/sdk/ai/azure-ai-inference/setup.py b/sdk/ai/azure-ai-inference/setup.py index c264ae00239e..c7b5395a3f9f 100644 --- a/sdk/ai/azure-ai-inference/setup.py +++ b/sdk/ai/azure-ai-inference/setup.py @@ -35,7 +35,7 @@ license="MIT License", author="Microsoft Corporation", author_email="azpysdkhelp@microsoft.com", - url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/ai/azure-ai-inference", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk", keywords="azure, azure sdk", classifiers=[ "Development Status :: 4 - Beta", @@ -68,7 +68,4 @@ "typing-extensions>=4.6.0", ], python_requires=">=3.8", - extras_require={ - 'opentelemetry': ['azure-core-tracing-opentelemetry'] - } ) diff --git a/sdk/ai/azure-ai-inference/tests/gen_ai_trace_verifier.py b/sdk/ai/azure-ai-inference/tests/gen_ai_trace_verifier.py index 29bb2ef57f47..a105b60cf8ac 100644 --- a/sdk/ai/azure-ai-inference/tests/gen_ai_trace_verifier.py +++ b/sdk/ai/azure-ai-inference/tests/gen_ai_trace_verifier.py @@ -10,11 +10,11 @@ class GenAiTraceVerifier: def check_span_attributes(self, span, attributes): - # Convert the list of tuples to a dictionary for easier lookup + # Convert the list of tuples to a dictionary for easier lookup attribute_dict = dict(attributes) - + for attribute_name in span.attributes.keys(): - # Check if the attribute name exists in the input attributes + # Check if the attribute name exists in the input attributes if attribute_name not in attribute_dict: return False @@ -26,7 +26,7 @@ def check_span_attributes(self, span, attributes): elif isinstance(attribute_value, tuple): # Check if the attribute value in the span matches the provided list if span.attributes[attribute_name] != attribute_value: - return False + return False else: # Check if the attribute value matches the provided value if attribute_value == "+": @@ -62,7 +62,7 @@ def check_event_attributes(self, expected_dict, actual_dict): return False for key, expected_val in expected_dict.items(): if key not in actual_dict: - return False + return False actual_val = actual_dict[key] if self.is_valid_json(expected_val): @@ -72,17 +72,17 @@ def check_event_attributes(self, expected_dict, actual_dict): return False elif isinstance(expected_val, dict): if not isinstance(actual_val, dict): - return False + return False if not self.check_event_attributes(expected_val, actual_val): return False - elif isinstance(expected_val, list): - if not isinstance(actual_val, list): + elif isinstance(expected_val, list): + if not isinstance(actual_val, list): return False if len(expected_val) != len(actual_val): return False - for expected_list, actual_list in zip(expected_val, actual_val): - if not self.check_event_attributes(expected_list, actual_list): - return False + for expected_list, actual_list in zip(expected_val, actual_val): + if not self.check_event_attributes(expected_list, actual_list): + return False elif isinstance(expected_val, str) and expected_val == "*": if actual_val == "": return False @@ -95,8 +95,8 @@ def check_span_events(self, span, expected_events): for expected_event in expected_events: for actual_event in span_events: - if expected_event['name'] == actual_event.name: - if not self.check_event_attributes(expected_event['attributes'], actual_event.attributes): + if expected_event["name"] == actual_event.name: + if not self.check_event_attributes(expected_event["attributes"], actual_event.attributes): return False span_events.remove(actual_event) # Remove the matched event from the span_events break diff --git a/sdk/ai/azure-ai-inference/tests/memory_trace_exporter.py b/sdk/ai/azure-ai-inference/tests/memory_trace_exporter.py index 7b609fbf5724..d0007f6f1bdc 100644 --- a/sdk/ai/azure-ai-inference/tests/memory_trace_exporter.py +++ b/sdk/ai/azure-ai-inference/tests/memory_trace_exporter.py @@ -34,6 +34,6 @@ def get_spans_by_name_starts_with(self, name_prefix: str) -> List[Span]: def get_spans_by_name(self, name: str) -> List[Span]: return [span for span in self._trace_list if span.name == name] - + def get_spans(self) -> List[Span]: - return [span for span in self._trace_list] \ No newline at end of file + return [span for span in self._trace_list] diff --git a/sdk/ai/azure-ai-inference/tests/test_model_inference_async_client.py b/sdk/ai/azure-ai-inference/tests/test_model_inference_async_client.py index 3be34667d424..e0f7360dc476 100644 --- a/sdk/ai/azure-ai-inference/tests/test_model_inference_async_client.py +++ b/sdk/ai/azure-ai-inference/tests/test_model_inference_async_client.py @@ -28,6 +28,7 @@ CONTENT_TRACING_ENV_VARIABLE = "AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED" content_tracing_initial_value = os.getenv(CONTENT_TRACING_ENV_VARIABLE) + # The test class name needs to start with "Test" to get collected by pytest class TestModelAsyncClient(ModelClientTestBase): @@ -492,7 +493,7 @@ async def test_async_load_chat_completions_client(self, **kwargs): response1 = await client.get_model_info() self._print_model_info_result(response1) self._validate_model_info_result( - response1, "chat-completion" # TODO: This should be chat_completions based on REST API spec... + response1, "chat-completion" # TODO: This should be chat_completions based on REST API spec... ) # TODO: This should be ModelType.CHAT once the model is fixed await client.close() @@ -737,27 +738,29 @@ async def test_chat_completion_async_tracing_content_recording_disabled(self, ** spans = exporter.get_spans_by_name("chat") assert len(spans) == 1 span = spans[0] - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(span, expected_attributes) assert attributes_match == True expected_events = [ - { - 'name': 'gen_ai.choice', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"finish_reason": "stop", "index": 0}' - } + { + "name": "gen_ai.choice", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"finish_reason": "stop", "index": 0}', + }, } ] events_match = GenAiTraceVerifier().check_span_events(span, expected_events) assert events_match == True - AIInferenceInstrumentor().uninstrument() \ No newline at end of file + AIInferenceInstrumentor().uninstrument() diff --git a/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py b/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py index a6cfffea8e8a..9bcb374554a2 100644 --- a/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py +++ b/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines # ------------------------------------ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. @@ -27,6 +28,7 @@ CONTENT_TRACING_ENV_VARIABLE = "AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED" content_tracing_initial_value = os.getenv(CONTENT_TRACING_ENV_VARIABLE) + # The test class name needs to start with "Test" to get collected by pytest class TestModelClient(ModelClientTestBase): @@ -559,7 +561,7 @@ def test_get_model_info_on_chat_client(self, **kwargs): self._print_model_info_result(response1) self._validate_model_info_result( - response1, "chat-completion" # TODO: This should be chat_comletions according to REST API spec... + response1, "chat-completion" # TODO: This should be chat_comletions according to REST API spec... ) # TODO: This should be ModelType.CHAT once the model is fixed # Get the model info again. No network calls should be made here, @@ -810,7 +812,6 @@ def test_embeddings_on_chat_completion_endpoint(self, **kwargs): client.close() assert exception_caught - # ********************************************************************************** # # TRACING TESTS - CHAT COMPLETIONS @@ -942,25 +943,27 @@ def test_chat_completion_tracing_content_recording_disabled(self, **kwargs): spans = exporter.get_spans_by_name("chat") assert len(spans) == 1 span = spans[0] - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(span, expected_attributes) assert attributes_match == True expected_events = [ - { - 'name': 'gen_ai.choice', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"finish_reason": "stop", "index": 0}' - } + { + "name": "gen_ai.choice", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"finish_reason": "stop", "index": 0}', + }, } ] events_match = GenAiTraceVerifier().check_span_events(span, expected_events) @@ -991,40 +994,42 @@ def test_chat_completion_tracing_content_recording_enabled(self, **kwargs): spans = exporter.get_spans_by_name("chat") assert len(spans) == 1 span = spans[0] - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(span, expected_attributes) assert attributes_match == True expected_events = [ { - 'name': 'gen_ai.system.message', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"role": "system", "content": "You are a helpful assistant."}' - } + "name": "gen_ai.system.message", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"role": "system", "content": "You are a helpful assistant."}', + }, }, { - 'name': 'gen_ai.user.message', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"role": "user", "content": "What is the capital of France?"}' - } + "name": "gen_ai.user.message", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"role": "user", "content": "What is the capital of France?"}', + }, }, { - 'name': 'gen_ai.choice', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"message": {"content": "*"}, "finish_reason": "stop", "index": 0}' - } - } + "name": "gen_ai.choice", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"message": {"content": "*"}, "finish_reason": "stop", "index": 0}', + }, + }, ] events_match = GenAiTraceVerifier().check_span_events(span, expected_events) assert events_match == True @@ -1047,7 +1052,7 @@ def test_chat_completion_streaming_tracing_content_recording_disabled(self, **kw sdk.models.SystemMessage(content="You are a helpful assistant."), sdk.models.UserMessage(content="What is the capital of France?"), ], - stream=True + stream=True, ) response_content = "" for update in response: @@ -1061,25 +1066,27 @@ def test_chat_completion_streaming_tracing_content_recording_disabled(self, **kw spans = exporter.get_spans_by_name("chat") assert len(spans) == 1 span = spans[0] - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(span, expected_attributes) assert attributes_match == True expected_events = [ { - 'name': 'gen_ai.choice', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"finish_reason": "stop", "index": 0}' - } + "name": "gen_ai.choice", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"finish_reason": "stop", "index": 0}', + }, } ] events_match = GenAiTraceVerifier().check_span_events(span, expected_events) @@ -1103,7 +1110,7 @@ def test_chat_completion_streaming_tracing_content_recording_enabled(self, **kwa sdk.models.SystemMessage(content="You are a helpful assistant."), sdk.models.UserMessage(content="What is the capital of France?"), ], - stream=True + stream=True, ) response_content = "" for update in response: @@ -1117,40 +1124,42 @@ def test_chat_completion_streaming_tracing_content_recording_enabled(self, **kwa spans = exporter.get_spans_by_name("chat") assert len(spans) == 1 span = spans[0] - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(span, expected_attributes) assert attributes_match == True expected_events = [ { - 'name': 'gen_ai.system.message', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"role": "system", "content": "You are a helpful assistant."}' - } + "name": "gen_ai.system.message", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"role": "system", "content": "You are a helpful assistant."}', + }, }, { - 'name': 'gen_ai.user.message', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"role": "user", "content": "What is the capital of France?"}' - } + "name": "gen_ai.user.message", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"role": "user", "content": "What is the capital of France?"}', + }, }, { - 'name': 'gen_ai.choice', - 'attributes': { - 'gen_ai.system': 'az.ai.inference', - 'gen_ai.event.content': '{"message": {"content": "*"}, "finish_reason": "stop", "index": 0}' - } - } + "name": "gen_ai.choice", + "attributes": { + "gen_ai.system": "az.ai.inference", + "gen_ai.event.content": '{"message": {"content": "*"}, "finish_reason": "stop", "index": 0}', + }, + }, ] events_match = GenAiTraceVerifier().check_span_events(span, expected_events) assert events_match == True @@ -1165,7 +1174,16 @@ def test_chat_completion_with_function_call_tracing_content_recording_enabled(se except RuntimeError as e: pass import json - from azure.ai.inference.models import SystemMessage, UserMessage, CompletionsFinishReason, ToolMessage, AssistantMessage, ChatCompletionsToolCall, ChatCompletionsToolDefinition, FunctionDefinition + from azure.ai.inference.models import ( + SystemMessage, + UserMessage, + CompletionsFinishReason, + ToolMessage, + AssistantMessage, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + FunctionDefinition, + ) from azure.ai.inference import ChatCompletionsClient self.modify_env_var(CONTENT_TRACING_ENV_VARIABLE, "True") @@ -1197,7 +1215,7 @@ def get_weather(city: str) -> str: }, ) ) - messages=[ + messages = [ sdk.models.SystemMessage(content="You are a helpful assistant."), sdk.models.UserMessage(content="What is the weather in Seattle?"), ] @@ -1225,26 +1243,30 @@ def get_weather(city: str) -> str: if len(spans) == 0: spans = exporter.get_spans_by_name("chat") assert len(spans) == 2 - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('tool_calls',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("tool_calls",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[0], expected_attributes) assert attributes_match == True - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[1], expected_attributes) assert attributes_match == True @@ -1254,25 +1276,25 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}" - } + "gen_ai.event.content": '{"role": "system", "content": "You are a helpful assistant."}', + }, }, { "name": "gen_ai.user.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"user\", \"content\": \"What is the weather in Seattle?\"}" - } + "gen_ai.event.content": '{"role": "user", "content": "What is the weather in Seattle?"}', + }, }, { "name": "gen_ai.choice", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"message\": {\"content\": \"\", \"tool_calls\": [{\"function\": {\"arguments\": \"{\\\"city\\\":\\\"Seattle\\\"}\", \"call_id\": null, \"name\": \"get_weather\"}, \"id\": \"*\", \"type\": \"function\"}]}, \"finish_reason\": \"tool_calls\", \"index\": 0}" - } - } + "gen_ai.event.content": '{"message": {"content": "", "tool_calls": [{"function": {"arguments": "{\\"city\\":\\"Seattle\\"}", "call_id": null, "name": "get_weather"}, "id": "*", "type": "function"}]}, "finish_reason": "tool_calls", "index": 0}', + }, + }, ] events_match = GenAiTraceVerifier().check_span_events(spans[0], expected_events) assert events_match == True @@ -1283,43 +1305,43 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}" - } + "gen_ai.event.content": '{"role": "system", "content": "You are a helpful assistant."}', + }, }, { "name": "gen_ai.user.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"user\", \"content\": \"What is the weather in Seattle?\"}" - } + "gen_ai.event.content": '{"role": "user", "content": "What is the weather in Seattle?"}', + }, }, { "name": "gen_ai.assistant.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"assistant\", \"tool_calls\": [{\"function\": {\"arguments\": \"{\\\"city\\\": \\\"Seattle\\\"}\", \"call_id\": null, \"name\": \"get_weather\"}, \"id\": \"*\", \"type\": \"function\"}]}" - } + "gen_ai.event.content": '{"role": "assistant", "tool_calls": [{"function": {"arguments": "{\\"city\\": \\"Seattle\\"}", "call_id": null, "name": "get_weather"}, "id": "*", "type": "function"}]}', + }, }, { "name": "gen_ai.tool.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"tool\", \"tool_call_id\": \"*\", \"content\": \"Nice weather\"}" - } + "gen_ai.event.content": '{"role": "tool", "tool_call_id": "*", "content": "Nice weather"}', + }, }, { "name": "gen_ai.choice", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"message\": {\"content\": \"*\"}, \"finish_reason\": \"stop\", \"index\": 0}" - } - } - ] - events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) + "gen_ai.event.content": '{"message": {"content": "*"}, "finish_reason": "stop", "index": 0}', + }, + }, + ] + events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) assert events_match == True AIInferenceInstrumentor().uninstrument() @@ -1333,7 +1355,16 @@ def test_chat_completion_with_function_call_tracing_content_recording_disabled(s except RuntimeError as e: pass import json - from azure.ai.inference.models import SystemMessage, UserMessage, CompletionsFinishReason, ToolMessage, AssistantMessage, ChatCompletionsToolCall, ChatCompletionsToolDefinition, FunctionDefinition + from azure.ai.inference.models import ( + SystemMessage, + UserMessage, + CompletionsFinishReason, + ToolMessage, + AssistantMessage, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + FunctionDefinition, + ) from azure.ai.inference import ChatCompletionsClient self.modify_env_var(CONTENT_TRACING_ENV_VARIABLE, "False") @@ -1365,7 +1396,7 @@ def get_weather(city: str) -> str: }, ) ) - messages=[ + messages = [ sdk.models.SystemMessage(content="You are a helpful assistant."), sdk.models.UserMessage(content="What is the weather in Seattle?"), ] @@ -1393,26 +1424,30 @@ def get_weather(city: str) -> str: if len(spans) == 0: spans = exporter.get_spans_by_name("chat") assert len(spans) == 2 - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('tool_calls',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("tool_calls",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[0], expected_attributes) assert attributes_match == True - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[1], expected_attributes) assert attributes_match == True @@ -1422,8 +1457,8 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"finish_reason\": \"tool_calls\", \"index\": 0, \"message\": {\"tool_calls\": [{\"function\": {\"call_id\": null}, \"id\": \"*\", \"type\": \"function\"}]}}" - } + "gen_ai.event.content": '{"finish_reason": "tool_calls", "index": 0, "message": {"tool_calls": [{"function": {"call_id": null}, "id": "*", "type": "function"}]}}', + }, } ] events_match = GenAiTraceVerifier().check_span_events(spans[0], expected_events) @@ -1435,11 +1470,11 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"finish_reason\": \"stop\", \"index\": 0}" - } + "gen_ai.event.content": '{"finish_reason": "stop", "index": 0}', + }, } - ] - events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) + ] + events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) assert events_match == True AIInferenceInstrumentor().uninstrument() @@ -1453,7 +1488,17 @@ def test_chat_completion_with_function_call_streaming_tracing_content_recording_ except RuntimeError as e: pass import json - from azure.ai.inference.models import SystemMessage, UserMessage, CompletionsFinishReason, FunctionCall, ToolMessage, AssistantMessage, ChatCompletionsToolCall, ChatCompletionsToolDefinition, FunctionDefinition + from azure.ai.inference.models import ( + SystemMessage, + UserMessage, + CompletionsFinishReason, + FunctionCall, + ToolMessage, + AssistantMessage, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + FunctionDefinition, + ) from azure.ai.inference import ChatCompletionsClient self.modify_env_var(CONTENT_TRACING_ENV_VARIABLE, "True") @@ -1485,15 +1530,12 @@ def get_weather(city: str) -> str: }, ) ) - messages=[ + messages = [ sdk.models.SystemMessage(content="You are a helpful AI assistant."), sdk.models.UserMessage(content="What is the weather in Seattle?"), ] - response = client.complete( - messages=messages, - tools=[weather_description], - stream=True) + response = client.complete(messages=messages, tools=[weather_description], stream=True) # At this point we expect a function tool call in the model response tool_call_id: str = "" @@ -1506,17 +1548,13 @@ def get_weather(city: str) -> str: if update.choices[0].delta.tool_calls[0].id is not None: tool_call_id = update.choices[0].delta.tool_calls[0].id function_args += update.choices[0].delta.tool_calls[0].function.arguments or "" - + # Append the previous model response to the chat history messages.append( AssistantMessage( tool_calls=[ ChatCompletionsToolCall( - id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=function_args - ) + id=tool_call_id, function=FunctionCall(name=function_name, arguments=function_args) ) ] ) @@ -1528,19 +1566,10 @@ def get_weather(city: str) -> str: function_response = callable_func(**function_args_mapping) # Append the function response as a tool message to the chat history - messages.append( - ToolMessage( - tool_call_id=tool_call_id, - content=function_response - ) - ) + messages.append(ToolMessage(tool_call_id=tool_call_id, content=function_response)) # With the additional tools information on hand, get another streaming response from the model - response = client.complete( - messages=messages, - tools=[weather_description], - stream=True - ) + response = client.complete(messages=messages, tools=[weather_description], stream=True) content = "" for update in response: @@ -1551,26 +1580,30 @@ def get_weather(city: str) -> str: if len(spans) == 0: spans = exporter.get_spans_by_name("chat") assert len(spans) == 2 - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('tool_calls',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("tool_calls",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[0], expected_attributes) assert attributes_match == True - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[1], expected_attributes) assert attributes_match == True @@ -1580,25 +1613,25 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"system\", \"content\": \"You are a helpful AI assistant.\"}" - } + "gen_ai.event.content": '{"role": "system", "content": "You are a helpful AI assistant."}', + }, }, { "name": "gen_ai.user.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"user\", \"content\": \"What is the weather in Seattle?\"}" - } + "gen_ai.event.content": '{"role": "user", "content": "What is the weather in Seattle?"}', + }, }, { "name": "gen_ai.choice", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"finish_reason\": \"tool_calls\", \"message\": {\"tool_calls\": [{\"id\": \"*\", \"type\": \"function\", \"function\": {\"name\": \"get_weather\", \"arguments\": \"{\\\"city\\\": \\\"Seattle\\\"}\"}}]}, \"index\": 0}" - } - } + "gen_ai.event.content": '{"finish_reason": "tool_calls", "message": {"tool_calls": [{"id": "*", "type": "function", "function": {"name": "get_weather", "arguments": "{\\"city\\": \\"Seattle\\"}"}}]}, "index": 0}', + }, + }, ] events_match = GenAiTraceVerifier().check_span_events(spans[0], expected_events) assert events_match == True @@ -1609,43 +1642,43 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"system\", \"content\": \"You are a helpful AI assistant.\"}" - } + "gen_ai.event.content": '{"role": "system", "content": "You are a helpful AI assistant."}', + }, }, { "name": "gen_ai.user.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"user\", \"content\": \"What is the weather in Seattle?\"}" - } + "gen_ai.event.content": '{"role": "user", "content": "What is the weather in Seattle?"}', + }, }, { "name": "gen_ai.assistant.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"assistant\", \"tool_calls\": [{\"id\": \"*\", \"function\": {\"name\": \"get_weather\", \"arguments\": \"{\\\"city\\\": \\\"Seattle\\\"}\"}, \"type\": \"function\"}]}" - } + "gen_ai.event.content": '{"role": "assistant", "tool_calls": [{"id": "*", "function": {"name": "get_weather", "arguments": "{\\"city\\": \\"Seattle\\"}"}, "type": "function"}]}', + }, }, { "name": "gen_ai.tool.message", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"role\": \"tool\", \"tool_call_id\": \"*\", \"content\": \"Nice weather\"}" - } + "gen_ai.event.content": '{"role": "tool", "tool_call_id": "*", "content": "Nice weather"}', + }, }, { "name": "gen_ai.choice", "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"message\": {\"content\": \"*\"}, \"finish_reason\": \"stop\", \"index\": 0}" - } - } - ] - events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) + "gen_ai.event.content": '{"message": {"content": "*"}, "finish_reason": "stop", "index": 0}', + }, + }, + ] + events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) assert events_match == True AIInferenceInstrumentor().uninstrument() @@ -1659,7 +1692,17 @@ def test_chat_completion_with_function_call_streaming_tracing_content_recording_ except RuntimeError as e: pass import json - from azure.ai.inference.models import SystemMessage, UserMessage, CompletionsFinishReason, FunctionCall, ToolMessage, AssistantMessage, ChatCompletionsToolCall, ChatCompletionsToolDefinition, FunctionDefinition + from azure.ai.inference.models import ( + SystemMessage, + UserMessage, + CompletionsFinishReason, + FunctionCall, + ToolMessage, + AssistantMessage, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + FunctionDefinition, + ) from azure.ai.inference import ChatCompletionsClient self.modify_env_var(CONTENT_TRACING_ENV_VARIABLE, "False") @@ -1691,15 +1734,12 @@ def get_weather(city: str) -> str: }, ) ) - messages=[ + messages = [ sdk.models.SystemMessage(content="You are a helpful assistant."), sdk.models.UserMessage(content="What is the weather in Seattle?"), ] - response = client.complete( - messages=messages, - tools=[weather_description], - stream=True) + response = client.complete(messages=messages, tools=[weather_description], stream=True) # At this point we expect a function tool call in the model response tool_call_id: str = "" @@ -1712,17 +1752,13 @@ def get_weather(city: str) -> str: if update.choices[0].delta.tool_calls[0].id is not None: tool_call_id = update.choices[0].delta.tool_calls[0].id function_args += update.choices[0].delta.tool_calls[0].function.arguments or "" - + # Append the previous model response to the chat history messages.append( AssistantMessage( tool_calls=[ ChatCompletionsToolCall( - id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=function_args - ) + id=tool_call_id, function=FunctionCall(name=function_name, arguments=function_args) ) ] ) @@ -1734,19 +1770,10 @@ def get_weather(city: str) -> str: function_response = callable_func(**function_args_mapping) # Append the function response as a tool message to the chat history - messages.append( - ToolMessage( - tool_call_id=tool_call_id, - content=function_response - ) - ) + messages.append(ToolMessage(tool_call_id=tool_call_id, content=function_response)) # With the additional tools information on hand, get another streaming response from the model - response = client.complete( - messages=messages, - tools=[weather_description], - stream=True - ) + response = client.complete(messages=messages, tools=[weather_description], stream=True) content = "" for update in response: @@ -1757,26 +1784,30 @@ def get_weather(city: str) -> str: if len(spans) == 0: spans = exporter.get_spans_by_name("chat") assert len(spans) == 2 - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('tool_calls',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("tool_calls",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[0], expected_attributes) assert attributes_match == True - expected_attributes = [('gen_ai.operation.name', 'chat'), - ('gen_ai.system', 'az.ai.inference'), - ('gen_ai.request.model', 'chat'), - ('server.address', ''), - ('gen_ai.response.id', ''), - ('gen_ai.response.model', 'mistral-large'), - ('gen_ai.usage.input_tokens', '+'), - ('gen_ai.usage.output_tokens', '+'), - ('gen_ai.response.finish_reasons', ('stop',))] + expected_attributes = [ + ("gen_ai.operation.name", "chat"), + ("gen_ai.system", "az.ai.inference"), + ("gen_ai.request.model", "chat"), + ("server.address", ""), + ("gen_ai.response.id", ""), + ("gen_ai.response.model", "mistral-large"), + ("gen_ai.usage.input_tokens", "+"), + ("gen_ai.usage.output_tokens", "+"), + ("gen_ai.response.finish_reasons", ("stop",)), + ] attributes_match = GenAiTraceVerifier().check_span_attributes(spans[1], expected_attributes) assert attributes_match == True @@ -1786,8 +1817,8 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"finish_reason\": \"tool_calls\", \"message\": {\"tool_calls\": [{\"id\": \"*\", \"type\": \"function\"}]}, \"index\": 0}" - } + "gen_ai.event.content": '{"finish_reason": "tool_calls", "message": {"tool_calls": [{"id": "*", "type": "function"}]}, "index": 0}', + }, } ] events_match = GenAiTraceVerifier().check_span_events(spans[0], expected_events) @@ -1799,11 +1830,11 @@ def get_weather(city: str) -> str: "timestamp": "*", "attributes": { "gen_ai.system": "az.ai.inference", - "gen_ai.event.content": "{\"finish_reason\": \"stop\", \"index\": 0}" - } + "gen_ai.event.content": '{"finish_reason": "stop", "index": 0}', + }, } - ] - events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) + ] + events_match = GenAiTraceVerifier().check_span_events(spans[1], expected_events) assert events_match == True - AIInferenceInstrumentor().uninstrument() \ No newline at end of file + AIInferenceInstrumentor().uninstrument() diff --git a/sdk/ai/azure-ai-inference/tsp-location.yaml b/sdk/ai/azure-ai-inference/tsp-location.yaml index df185250688b..268600f22b4d 100644 --- a/sdk/ai/azure-ai-inference/tsp-location.yaml +++ b/sdk/ai/azure-ai-inference/tsp-location.yaml @@ -1,4 +1,4 @@ directory: specification/ai/ModelClient -commit: 3e95e575e537024a02470cf59c7a78078dc10cd1 +commit: b1ad4b2d6b802834ce695f4b21da2af587f53fba repo: Azure/azure-rest-api-specs -additionalDirectories: +additionalDirectories: