diff --git a/sdk/ai/azure-ai-inference/_meta.json b/sdk/ai/azure-ai-inference/_meta.json new file mode 100644 index 000000000000..8409c7166d1e --- /dev/null +++ b/sdk/ai/azure-ai-inference/_meta.json @@ -0,0 +1,6 @@ +{ + "commit": "834383067ac02f95702b17f494fc1df973bd9455", + "repository_url": "https://github.com/Azure/azure-rest-api-specs", + "typespec_src": "specification/ai/ModelClient", + "@azure-tools/typespec-python": "0.37.2" +} \ No newline at end of file diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py index ff62b276a309..b7537d16cab3 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/__init__.py @@ -5,24 +5,32 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._patch import ChatCompletionsClient -from ._patch import EmbeddingsClient -from ._patch import ImageEmbeddingsClient +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._client import ChatCompletionsClient # type: ignore +from ._client import EmbeddingsClient # type: ignore +from ._client import ImageEmbeddingsClient # type: ignore from ._version import VERSION __version__ = VERSION - -from ._patch import load_client +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] from ._patch import patch_sdk as _patch_sdk __all__ = [ - "load_client", "ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient", ] - +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py index 25f4b3746e76..5e73e91ea2b2 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_client.py @@ -28,11 +28,10 @@ from ._serialization import Deserializer, Serializer if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials import TokenCredential -class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): """ChatCompletionsClient. :param endpoint: Service host. Required. @@ -110,7 +109,7 @@ def __exit__(self, *exc_details: Any) -> None: self._client.__exit__(*exc_details) -class EmbeddingsClient(EmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class EmbeddingsClient(EmbeddingsClientOperationsMixin): """EmbeddingsClient. :param endpoint: Service host. Required. @@ -188,7 +187,7 @@ def __exit__(self, *exc_details: Any) -> None: self._client.__exit__(*exc_details) -class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): """ImageEmbeddingsClient. :param endpoint: Service host. Required. 
diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py index 9393659fb910..8158dd310196 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_configuration.py @@ -14,11 +14,10 @@ from ._version import VERSION if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials import TokenCredential -class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ChatCompletionsClient. Note that all parameters used to create this instance are saved as instance @@ -75,7 +74,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for EmbeddingsClient. Note that all parameters used to create this instance are saved as instance @@ -132,7 +131,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ImageEmbeddingsClient. Note that all parameters used to create this instance are saved as instance diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py index 53305e2213a7..7f73b97b23ef 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py @@ -1,10 +1,11 @@ +# pylint: disable=too-many-lines # coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- -# pylint: disable=protected-access, arguments-differ, signature-differs, broad-except +# pylint: disable=protected-access, broad-except import copy import calendar @@ -19,6 +20,7 @@ import email.utils from datetime import datetime, date, time, timedelta, timezone from json import JSONEncoder +import xml.etree.ElementTree as ET from typing_extensions import Self import isodate from azure.core.exceptions import DeserializationError @@ -123,7 +125,7 @@ def _serialize_datetime(o, format: typing.Optional[str] = None): def _is_readonly(p): try: - return p._visibility == ["read"] # pylint: disable=protected-access + return p._visibility == ["read"] except AttributeError: return False @@ -286,6 +288,12 @@ def _deserialize_decimal(attr): return decimal.Decimal(str(attr)) +def _deserialize_int_as_str(attr): + if isinstance(attr, int): + return attr + return int(attr) + + _DESERIALIZE_MAPPING = { datetime: _deserialize_datetime, date: _deserialize_date, @@ -307,9 +315,11 @@ def _deserialize_decimal(attr): def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None): + if annotation is int and rf and rf._format == "str": + return _deserialize_int_as_str if rf and rf._format: return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format) - return _DESERIALIZE_MAPPING.get(annotation) + return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore def _get_type_alias_type(module_name: str, alias_name: str): @@ -441,6 +451,10 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m return float(o) if isinstance(o, enum.Enum): return o.value + if isinstance(o, int): + if format == "str": + return str(o) + return o try: # First try datetime.datetime return _serialize_datetime(o, format) @@ -471,6 +485,8 @@ def _create_value(rf: typing.Optional["_RestField"], value: typing.Any) -> typin return value if rf._is_model: return _deserialize(rf._type, value) + if isinstance(value, ET.Element): + value = _deserialize(rf._type, value) return _serialize(value, rf._format) @@ -489,10 +505,58 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: for rest_field in self._attr_to_rest_field.values() if rest_field._default is not _UNSET } - if args: - dict_to_pass.update( - {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()} - ) + if args: # pylint: disable=too-many-nested-blocks + if isinstance(args[0], ET.Element): + existed_attr_keys = [] + model_meta = getattr(self, "_xml", {}) + + for rf in self._attr_to_rest_field.values(): + prop_meta = getattr(rf, "_xml", {}) + xml_name = prop_meta.get("name", rf._rest_name) + xml_ns = prop_meta.get("ns", model_meta.get("ns", None)) + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + + # attribute + if prop_meta.get("attribute", False) and args[0].get(xml_name) is not None: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].get(xml_name)) + continue + + # unwrapped element is array + if prop_meta.get("unwrapped", False): + # unwrapped array could either use prop items meta/prop meta + if prop_meta.get("itemsName"): + xml_name = prop_meta.get("itemsName") + xml_ns = prop_meta.get("itemNs") + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + items = args[0].findall(xml_name) # pyright: ignore + if len(items) > 0: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, items) + continue + + # text element is 
primitive type + if prop_meta.get("text", False): + if args[0].text is not None: + dict_to_pass[rf._rest_name] = _deserialize(rf._type, args[0].text) + continue + + # wrapped element could be normal property or array, it should only have one element + item = args[0].find(xml_name) + if item is not None: + existed_attr_keys.append(xml_name) + dict_to_pass[rf._rest_name] = _deserialize(rf._type, item) + + # rest thing is additional properties + for e in args[0]: + if e.tag not in existed_attr_keys: + dict_to_pass[e.tag] = _convert_element(e) + else: + dict_to_pass.update( + {k: _create_value(_get_rest_field(self._attr_to_rest_field, k), v) for k, v in args[0].items()} + ) else: non_attr_kwargs = [k for k in kwargs if k not in self._attr_to_rest_field] if non_attr_kwargs: @@ -510,7 +574,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: def copy(self) -> "Model": return Model(self.__dict__) - def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: disable=unused-argument + def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: if f"{cls.__module__}.{cls.__qualname__}" not in cls._calculated: # we know the last nine classes in mro are going to be 'Model', '_MyMutableMapping', 'MutableMapping', # 'Mapping', 'Collection', 'Sized', 'Iterable', 'Container' and 'object' @@ -521,8 +585,8 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: di annotations = { k: v for mro_class in mros - if hasattr(mro_class, "__annotations__") # pylint: disable=no-member - for k, v in mro_class.__annotations__.items() # pylint: disable=no-member + if hasattr(mro_class, "__annotations__") + for k, v in mro_class.__annotations__.items() } for attr, rf in attr_to_rest_field.items(): rf._module = cls.__module__ @@ -537,31 +601,43 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self: # pylint: di def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: for base in cls.__bases__: - if hasattr(base, "__mapping__"): # pylint: disable=no-member - base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member + if hasattr(base, "__mapping__"): + base.__mapping__[discriminator or cls.__name__] = cls # type: ignore @classmethod - def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]: for v in cls.__dict__.values(): - if ( - isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators - ): # pylint: disable=protected-access - return v._rest_name # pylint: disable=protected-access + if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: + return v return None @classmethod def _deserialize(cls, data, exist_discriminators): - if not hasattr(cls, "__mapping__"): # pylint: disable=no-member + if not hasattr(cls, "__mapping__"): return cls(data) discriminator = cls._get_discriminator(exist_discriminators) - exist_discriminators.append(discriminator) - mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pyright: ignore # pylint: disable=no-member - if mapped_cls == cls: + if discriminator is None: return cls(data) - return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + exist_discriminators.append(discriminator._rest_name) + if isinstance(data, ET.Element): + model_meta = getattr(cls, "_xml", {}) + prop_meta = getattr(discriminator, "_xml", {}) + 
xml_name = prop_meta.get("name", discriminator._rest_name) + xml_ns = prop_meta.get("ns", model_meta.get("ns", None)) + if xml_ns: + xml_name = "{" + xml_ns + "}" + xml_name + + if data.get(xml_name) is not None: + discriminator_value = data.get(xml_name) + else: + discriminator_value = data.find(xml_name).text # pyright: ignore + else: + discriminator_value = data.get(discriminator._rest_name) + mapped_cls = cls.__mapping__.get(discriminator_value, cls) # pyright: ignore + return mapped_cls._deserialize(data, exist_discriminators) def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: - """Return a dict that can be JSONify using json.dump. + """Return a dict that can be turned into json using json.dump. :keyword bool exclude_readonly: Whether to remove the readonly properties. :returns: A dict JSON compatible object @@ -624,6 +700,8 @@ def _deserialize_dict( ): if obj is None: return obj + if isinstance(obj, ET.Element): + obj = {child.tag: child for child in obj} return {k: _deserialize(value_deserializer, v, module) for k, v in obj.items()} @@ -644,6 +722,8 @@ def _deserialize_sequence( ): if obj is None: return obj + if isinstance(obj, ET.Element): + obj = list(obj) return type(obj)(_deserialize(deserializer, entry, module) for entry in obj) @@ -654,12 +734,12 @@ def _sorted_annotations(types: typing.List[typing.Any]) -> typing.List[typing.An ) -def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 +def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-branches annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, ) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]: - if not annotation or annotation in [int, float]: + if not annotation: return None # is it a type alias? @@ -734,7 +814,6 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, try: if annotation._name in ["List", "Set", "Tuple", "Sequence"]: # pyright: ignore if len(annotation.__args__) > 1: # pyright: ignore - entry_deserializers = [ _get_deserialize_callable_from_annotation(dt, module, rf) for dt in annotation.__args__ # pyright: ignore @@ -769,12 +848,23 @@ def _deserialize_default( def _deserialize_with_callable( deserializer: typing.Optional[typing.Callable[[typing.Any], typing.Any]], value: typing.Any, -): +): # pylint: disable=too-many-return-statements try: if value is None or isinstance(value, _Null): return None + if isinstance(value, ET.Element): + if deserializer is str: + return value.text or "" + if deserializer is int: + return int(value.text) if value.text else None + if deserializer is float: + return float(value.text) if value.text else None + if deserializer is bool: + return value.text == "true" if value.text else None if deserializer is None: return value + if deserializer in [int, float, bool]: + return deserializer(value) if isinstance(deserializer, CaseInsensitiveEnumMeta): try: return deserializer(value) @@ -804,6 +894,22 @@ def _deserialize( return _deserialize_with_callable(deserializer, value) +def _failsafe_deserialize( + deserializer: typing.Any, + value: typing.Any, + module: typing.Optional[str] = None, + rf: typing.Optional["_RestField"] = None, + format: typing.Optional[str] = None, +) -> typing.Any: + try: + return _deserialize(deserializer, value, module, rf, format) + except DeserializationError: + _LOGGER.warning( + "Ran into a deserialization error. 
Ignoring since this is failsafe deserialization", exc_info=True + ) + return None + + class _RestField: def __init__( self, @@ -815,6 +921,7 @@ def __init__( default: typing.Any = _UNSET, format: typing.Optional[str] = None, is_multipart_file_input: bool = False, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, ): self._type = type self._rest_name_input = name @@ -825,6 +932,7 @@ def __init__( self._default = default self._format = format self._is_multipart_file_input = is_multipart_file_input + self._xml = xml if xml is not None else {} @property def _class_type(self) -> typing.Any: @@ -875,6 +983,7 @@ def rest_field( default: typing.Any = _UNSET, format: typing.Optional[str] = None, is_multipart_file_input: bool = False, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> typing.Any: return _RestField( name=name, @@ -883,6 +992,7 @@ def rest_field( default=default, format=format, is_multipart_file_input=is_multipart_file_input, + xml=xml, ) @@ -891,5 +1001,175 @@ def rest_discriminator( name: typing.Optional[str] = None, type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin visibility: typing.Optional[typing.List[str]] = None, + xml: typing.Optional[typing.Dict[str, typing.Any]] = None, +) -> typing.Any: + return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility, xml=xml) + + +def serialize_xml(model: Model, exclude_readonly: bool = False) -> str: + """Serialize a model to XML. + + :param Model model: The model to serialize. + :param bool exclude_readonly: Whether to exclude readonly properties. + :returns: The XML representation of the model. + :rtype: str + """ + return ET.tostring(_get_element(model, exclude_readonly), encoding="unicode") # type: ignore + + +def _get_element( + o: typing.Any, + exclude_readonly: bool = False, + parent_meta: typing.Optional[typing.Dict[str, typing.Any]] = None, + wrapped_element: typing.Optional[ET.Element] = None, +) -> typing.Union[ET.Element, typing.List[ET.Element]]: + if _is_model(o): + model_meta = getattr(o, "_xml", {}) + + # if prop is a model, then use the prop element directly, else generate a wrapper of model + if wrapped_element is None: + wrapped_element = _create_xml_element( + model_meta.get("name", o.__class__.__name__), + model_meta.get("prefix"), + model_meta.get("ns"), + ) + + readonly_props = [] + if exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + + for k, v in o.items(): + # do not serialize readonly properties + if exclude_readonly and k in readonly_props: + continue + + prop_rest_field = _get_rest_field(o._attr_to_rest_field, k) + if prop_rest_field: + prop_meta = getattr(prop_rest_field, "_xml").copy() + # use the wire name as xml name if no specific name is set + if prop_meta.get("name") is None: + prop_meta["name"] = k + else: + # additional properties will not have rest field, use the wire name as xml name + prop_meta = {"name": k} + + # if no ns for prop, use model's + if prop_meta.get("ns") is None and model_meta.get("ns"): + prop_meta["ns"] = model_meta.get("ns") + prop_meta["prefix"] = model_meta.get("prefix") + + if prop_meta.get("unwrapped", False): + # unwrapped could only set on array + wrapped_element.extend(_get_element(v, exclude_readonly, prop_meta)) + elif prop_meta.get("text", False): + # text could only set on primitive type + wrapped_element.text = _get_primitive_type_value(v) + elif prop_meta.get("attribute", False): + xml_name = prop_meta.get("name", k) + if 
prop_meta.get("ns"): + ET.register_namespace(prop_meta.get("prefix"), prop_meta.get("ns")) # pyright: ignore + xml_name = "{" + prop_meta.get("ns") + "}" + xml_name # pyright: ignore + # attribute should be primitive type + wrapped_element.set(xml_name, _get_primitive_type_value(v)) + else: + # other wrapped prop element + wrapped_element.append(_get_wrapped_element(v, exclude_readonly, prop_meta)) + return wrapped_element + if isinstance(o, list): + return [_get_element(x, exclude_readonly, parent_meta) for x in o] # type: ignore + if isinstance(o, dict): + result = [] + for k, v in o.items(): + result.append( + _get_wrapped_element( + v, + exclude_readonly, + { + "name": k, + "ns": parent_meta.get("ns") if parent_meta else None, + "prefix": parent_meta.get("prefix") if parent_meta else None, + }, + ) + ) + return result + + # primitive case need to create element based on parent_meta + if parent_meta: + return _get_wrapped_element( + o, + exclude_readonly, + { + "name": parent_meta.get("itemsName", parent_meta.get("name")), + "prefix": parent_meta.get("itemsPrefix", parent_meta.get("prefix")), + "ns": parent_meta.get("itemsNs", parent_meta.get("ns")), + }, + ) + + raise ValueError("Could not serialize value into xml: " + o) + + +def _get_wrapped_element( + v: typing.Any, + exclude_readonly: bool, + meta: typing.Optional[typing.Dict[str, typing.Any]], +) -> ET.Element: + wrapped_element = _create_xml_element( + meta.get("name") if meta else None, meta.get("prefix") if meta else None, meta.get("ns") if meta else None + ) + if isinstance(v, (dict, list)): + wrapped_element.extend(_get_element(v, exclude_readonly, meta)) + elif _is_model(v): + _get_element(v, exclude_readonly, meta, wrapped_element) + else: + wrapped_element.text = _get_primitive_type_value(v) + return wrapped_element + + +def _get_primitive_type_value(v) -> str: + if v is True: + return "true" + if v is False: + return "false" + if isinstance(v, _Null): + return "" + return str(v) + + +def _create_xml_element(tag, prefix=None, ns=None): + if prefix and ns: + ET.register_namespace(prefix, ns) + if ns: + return ET.Element("{" + ns + "}" + tag) + return ET.Element(tag) + + +def _deserialize_xml( + deserializer: typing.Any, + value: str, ) -> typing.Any: - return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility) + element = ET.fromstring(value) # nosec + return _deserialize(deserializer, element) + + +def _convert_element(e: ET.Element): + # dict case + if len(e.attrib) > 0 or len({child.tag for child in e}) > 1: + dict_result: typing.Dict[str, typing.Any] = {} + for child in e: + if dict_result.get(child.tag) is not None: + if isinstance(dict_result[child.tag], list): + dict_result[child.tag].append(_convert_element(child)) + else: + dict_result[child.tag] = [dict_result[child.tag], _convert_element(child)] + else: + dict_result[child.tag] = _convert_element(child) + dict_result.update(e.attrib) + return dict_result + # array case + if len(e) > 0: + array_result: typing.List[typing.Any] = [] + for child in e: + array_result.append(_convert_element(child)) + return array_result + # primitive case + return e.text diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py index d3ebd561f739..ab87088736aa 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/__init__.py @@ -5,13 +5,19 @@ # Code generated by Microsoft (R) 
Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._operations import ChatCompletionsClientOperationsMixin -from ._operations import EmbeddingsClientOperationsMixin -from ._operations import ImageEmbeddingsClientOperationsMixin +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import ChatCompletionsClientOperationsMixin # type: ignore +from ._operations import EmbeddingsClientOperationsMixin # type: ignore +from ._operations import ImageEmbeddingsClientOperationsMixin # type: ignore from ._patch import __all__ as _patch_all -from ._patch import * # pylint: disable=unused-wildcard-import +from ._patch import * from ._patch import patch_sdk as _patch_sdk __all__ = [ @@ -19,5 +25,5 @@ "EmbeddingsClientOperationsMixin", "ImageEmbeddingsClientOperationsMixin", ] -__all__.extend([p for p in _patch_all if p not in __all__]) +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py index 3a24ee5736d3..f603f3615f8e 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_operations/_operations.py @@ -1,4 +1,3 @@ -# pylint: disable=too-many-lines,too-many-statements # coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. @@ -9,7 +8,7 @@ from io import IOBase import json import sys -from typing import Any, Callable, Dict, IO, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload from azure.core.exceptions import ( ClientAuthenticationError, @@ -34,7 +33,7 @@ if sys.version_info >= (3, 9): from collections.abc import MutableMapping else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports + from typing import MutableMapping # type: ignore JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object _Unset: Any = object() T = TypeVar("T") @@ -246,7 +245,6 @@ def _complete( model: Optional[str] = None, **kwargs: Any ) -> _models.ChatCompletions: - # pylint: disable=too-many-locals """Gets chat completions for the provided chat messages. 
Completions support a wide variety of tasks and generate text that continues from or "completes" @@ -335,7 +333,7 @@ def _complete( :rtype: ~azure.ai.inference.models.ChatCompletions :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -425,7 +423,7 @@ def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -556,7 +554,7 @@ def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -638,7 +636,7 @@ def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -703,7 +701,7 @@ def _embed( def _embed( self, *, - input: List[_models.EmbeddingInput], + input: List[_models.ImageEmbeddingInput], extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, content_type: str = "application/json", dimensions: Optional[int] = None, @@ -727,7 +725,7 @@ def _embed( self, body: Union[JSON, IO[bytes]] = _Unset, *, - input: List[_models.EmbeddingInput] = _Unset, + input: List[_models.ImageEmbeddingInput] = _Unset, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, dimensions: Optional[int] = None, encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, @@ -743,7 +741,7 @@ def _embed( :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an array. The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, are passed in the JSON request payload. This sets the HTTP request header ``extra-parameters``. 
Known values are: "error", "drop", and @@ -772,7 +770,7 @@ def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -854,7 +852,7 @@ def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py index 8d2ca4a4aaf1..f7dd32510333 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_patch.py @@ -2,1284 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -# pylint: disable=too-many-lines) """Customize generated code here. Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize - -Why do we patch auto-generated code? -1. Add support for input argument `model_extras` (all clients) -2. Add support for function load_client -3. Add support for setting sticky chat completions/embeddings input arguments in the client constructor -4. Add support for get_model_info, while caching the result (all clients) -5. Add support for chat completion streaming (ChatCompletionsClient client only) -6. Add support for friendly print of result objects (__str__ method) (all clients) -7. Add support for load() method in ImageUrl class (see /models/_patch.py) -8. Add support for sending two auth headers for api-key auth (all clients) - """ -import json -import logging -import sys - -from io import IOBase -from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, Iterable - -from azure.core.pipeline import PipelineResponse -from azure.core.credentials import AzureKeyCredential -from azure.core.tracing.decorator import distributed_trace -from azure.core.utils import case_insensitive_dict -from azure.core.exceptions import ( - ClientAuthenticationError, - HttpResponseError, - map_error, - ResourceExistsError, - ResourceNotFoundError, - ResourceNotModifiedError, -) -from . 
import models as _models -from ._model_base import SdkJSONEncoder, _deserialize -from ._serialization import Serializer -from ._operations._operations import ( - build_chat_completions_complete_request, - build_embeddings_embed_request, - build_image_embeddings_embed_request, -) -from ._client import ChatCompletionsClient as ChatCompletionsClientGenerated -from ._client import EmbeddingsClient as EmbeddingsClientGenerated -from ._client import ImageEmbeddingsClient as ImageEmbeddingsClientGenerated - -if sys.version_info >= (3, 9): - from collections.abc import MutableMapping -else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports - -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from azure.core.credentials import TokenCredential - -JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object -_Unset: Any = object() - -_SERIALIZER = Serializer() -_SERIALIZER.client_side_validation = False - -_LOGGER = logging.getLogger(__name__) - - -def load_client( - endpoint: str, credential: Union[AzureKeyCredential, "TokenCredential"], **kwargs: Any -) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: - """ - Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route - on the given endpoint, to determine the model type and therefore which client to instantiate. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - :return: The appropriate synchronous client associated with the given endpoint - :rtype: ~azure.ai.inference.ChatCompletionsClient or ~azure.ai.inference.EmbeddingsClient - or ~azure.ai.inference.ImageEmbeddingsClient - :raises ~azure.core.exceptions.HttpResponseError: - """ - - with ChatCompletionsClient( - endpoint, credential, **kwargs - ) as client: # Pick any of the clients, it does not matter. - model_info = client.get_model_info() # type: ignore - - _LOGGER.info("model_info=%s", model_info) - if not model_info.model_type: - raise ValueError( - "The AI model information is missing a value for `model type`. Cannot create an appropriate client." 
- ) - - # TODO: Remove "completions", "chat-comletions" and "embedding" once Mistral Large and Cohere fixes their model type - if model_info.model_type in (_models.ModelType.CHAT, "completion", "chat-completion", "chat-completions"): - chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs) - chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return chat_completion_client - - if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"): - embedding_client = EmbeddingsClient(endpoint, credential, **kwargs) - embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init - return embedding_client - - if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS: - image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs) - image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return image_embedding_client - - raise ValueError(f"No client available to support AI model type `{model_info.model_type}`") - - -class ChatCompletionsClient(ChatCompletionsClientGenerated): # pylint: disable=too-many-instance-attributes - """ChatCompletionsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. 
- :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "TokenCredential"], - *, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default chat completions settings, to be applied in all future service calls - # unless overridden by arguments in the `complete` method. 
- self._frequency_penalty = frequency_penalty - self._presence_penalty = presence_penalty - self._temperature = temperature - self._top_p = top_p - self._max_tokens = max_tokens - self._response_format = response_format - self._stop = stop - self._tools = tools - self._tool_choice = tool_choice - self._seed = seed - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - def complete( - self, - *, - messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]], - stream: Literal[False] = False, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.ChatCompletions: ... - - @overload - def complete( - self, - *, - messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]], - stream: Literal[True], - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Iterable[_models.StreamingChatCompletionsUpdate]: ... 
- - @overload - def complete( - self, - *, - messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]], - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. The method makes a REST API call to the `/chat/completions` route - on the given endpoint. - When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting StreamingChatCompletions - object to get content updates as they arrive. By default, the response is a ChatCompletions object - (non-streaming). - - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] or list[dict[str, Any]] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. 
This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def complete( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". 
- :paramtype content_type: str - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def complete( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - # pylint: disable=too-many-locals - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - # pylint:disable=client-method-missing-tracing-decorator - def complete( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - messages: Union[List[_models.ChatRequestMessage], List[Dict[str, Any]]] = _Unset, - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[Iterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - # pylint: disable=too-many-locals - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting :class:`~azure.ai.inference.models.StreamingChatCompletions` - object to get content updates as they arrive. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] or list[dict[str, Any]] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. 
- Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. 
- :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or Iterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.StreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if messages is _Unset: - raise TypeError("missing required argument: messages") - body = { - "messages": messages, - "stream": stream, - "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, - "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, - "model": model if model is not None else self._model, - "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, - "response_format": response_format if response_format is not None else self._response_format, - "seed": seed if seed is not None else self._seed, - "stop": stop if stop is not None else self._stop, - "temperature": temperature if temperature is not None else self._temperature, - "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, - "tools": tools if tools is not None else self._tools, - "top_p": top_p if top_p is not None else self._top_p, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): - stream = body["stream"] - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_chat_completions_complete_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = stream or False - pipeline_response: 
PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - return _models.StreamingChatCompletions(response) - - return _deserialize(_models._patch.ChatCompletions, response.json()) # pylint: disable=protected-access - - @distributed_trace - def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class EmbeddingsClient(EmbeddingsClientGenerated): - """EmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. 
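Continuing the sketch above: per the removed implementation, `model_extras` entries are merged as-is into the root of the JSON request body and, when non-empty, the `extra-parameters` request header is set to pass-through; `get_model_info` caches its first `/info` call. The `safe_mode` parameter below is hypothetical and only illustrates the pass-through behavior.

# Reuses `client` and UserMessage from the previous sketch.
response = client.complete(
    messages=[UserMessage(content="Hello")],
    # Hypothetical model-specific parameter, added as-is to the JSON body root;
    # the "extra-parameters: pass-through" header is sent because it is non-empty.
    model_extras={"safe_mode": True},
)
print(response.choices[0].message.content)

# Cached after the first call; only supported on Serverless API / Managed Compute endpoints.
info = client.get_model_info()
print(info.model_name, info.model_type, info.model_provider_name)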
- :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "TokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - def embed( - self, - *, - input: List[str], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace - def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[str] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. 
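A minimal text-embeddings sketch for the `embed` overloads above. The endpoint URL and key are placeholders, and the `dimensions` value is only honored by models that support configurable output dimensions.

from azure.ai.inference import EmbeddingsClient
from azure.core.credentials import AzureKeyCredential

embeddings_client = EmbeddingsClient(
    endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
    credential=AzureKeyCredential("<your-api-key>"),            # placeholder
    dimensions=1024,  # client-level default, used unless overridden per call
)

result = embeddings_client.embed(input=["first phrase", "second phrase", "third phrase"])
for item in result.data:
    print(f"index={item.index}: vector of length {len(item.embedding)}")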
- :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace - def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. 
- The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): - """ImageEmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a TokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials.TokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "TokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. 
"api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - def embed( - self, - *, - input: List[_models.EmbeddingInput], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace - def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[_models.EmbeddingInput] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_image_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace - def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. 
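A hedged sketch of the image-embeddings path shown above. "sample.png" is a hypothetical local file, the endpoint and key are placeholders, and the image is assumed to be passed as a base64 data URL in `EmbeddingInput.image` as described by that model's documentation.

import base64
from azure.ai.inference import ImageEmbeddingsClient
from azure.ai.inference.models import EmbeddingInput
from azure.core.credentials import AzureKeyCredential

image_client = ImageEmbeddingsClient(
    endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
    credential=AzureKeyCredential("<your-api-key>"),            # placeholder
)

# Encode a local image as a base64 data URL before sending it for embedding.
with open("sample.png", "rb") as f:
    image_data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("ascii")

result = image_client.embed(input=[EmbeddingInput(image=image_data_url)])
print(f"Embedding length: {len(result.data[0].embedding)}")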
The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - +from typing import List -__all__: List[str] = [ - "load_client", - "ChatCompletionsClient", - "EmbeddingsClient", - "ImageEmbeddingsClient", -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py index 8139854b97bb..b24ab2885450 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_serialization.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines # -------------------------------------------------------------------------- # # Copyright (c) Microsoft Corporation. All rights reserved. @@ -24,7 +25,6 @@ # # -------------------------------------------------------------------------- -# pylint: skip-file # pyright: reportUnnecessaryTypeIgnoreComment=false from base64 import b64decode, b64encode @@ -52,7 +52,6 @@ MutableMapping, Type, List, - Mapping, ) try: @@ -91,6 +90,8 @@ def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: :param data: Input, could be bytes or stream (will be decoded with UTF8) or text :type data: str or bytes or IO :param str content_type: The content type. + :return: The deserialized data. + :rtype: object """ if hasattr(data, "read"): # Assume a stream @@ -112,7 +113,7 @@ def deserialize_from_text(cls, data: Optional[Union[AnyStr, IO]], content_type: try: return json.loads(data_as_str) except ValueError as err: - raise DeserializationError("JSON is invalid: {}".format(err), err) + raise DeserializationError("JSON is invalid: {}".format(err), err) from err elif "xml" in (content_type or []): try: @@ -155,6 +156,11 @@ def deserialize_from_http_generics(cls, body_bytes: Optional[Union[AnyStr, IO]], Use bytes and headers to NOT use any requests/aiohttp or whatever specific implementation. Headers will tested for "content-type" + + :param bytes body_bytes: The body of the response. + :param dict headers: The headers of the response. + :returns: The deserialized data. + :rtype: object """ # Try to use content-type from headers if available content_type = None @@ -184,15 +190,30 @@ class UTC(datetime.tzinfo): """Time Zone info for handling UTC""" def utcoffset(self, dt): - """UTF offset for UTC is 0.""" + """UTF offset for UTC is 0. + + :param datetime.datetime dt: The datetime + :returns: The offset + :rtype: datetime.timedelta + """ return datetime.timedelta(0) def tzname(self, dt): - """Timestamp representation.""" + """Timestamp representation. + + :param datetime.datetime dt: The datetime + :returns: The timestamp representation + :rtype: str + """ return "Z" def dst(self, dt): - """No daylight saving for UTC.""" + """No daylight saving for UTC. 
+ + :param datetime.datetime dt: The datetime + :returns: The daylight saving time + :rtype: datetime.timedelta + """ return datetime.timedelta(hours=1) @@ -206,7 +227,7 @@ class _FixedOffset(datetime.tzinfo): # type: ignore :param datetime.timedelta offset: offset in timedelta format """ - def __init__(self, offset): + def __init__(self, offset) -> None: self.__offset = offset def utcoffset(self, dt): @@ -235,24 +256,26 @@ def __getinitargs__(self): _FLATTEN = re.compile(r"(? None: self.additional_properties: Optional[Dict[str, Any]] = {} - for k in kwargs: + for k in kwargs: # pylint: disable=consider-using-dict-items if k not in self._attribute_map: _LOGGER.warning("%s is not a known attribute of class %s and will be ignored", k, self.__class__) elif k in self._validation and self._validation[k].get("readonly", False): @@ -300,13 +330,23 @@ def __init__(self, **kwargs: Any) -> None: setattr(self, k, kwargs[k]) def __eq__(self, other: Any) -> bool: - """Compare objects by comparing all attributes.""" + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are equal + :rtype: bool + """ if isinstance(other, self.__class__): return self.__dict__ == other.__dict__ return False def __ne__(self, other: Any) -> bool: - """Compare objects by comparing all attributes.""" + """Compare objects by comparing all attributes. + + :param object other: The object to compare + :returns: True if objects are not equal + :rtype: bool + """ return not self.__eq__(other) def __str__(self) -> str: @@ -326,7 +366,11 @@ def is_xml_model(cls) -> bool: @classmethod def _create_xml_node(cls): - """Create XML node.""" + """Create XML node. + + :returns: The XML node + :rtype: xml.etree.ElementTree.Element + """ try: xml_map = cls._xml_map # type: ignore except AttributeError: @@ -346,7 +390,9 @@ def serialize(self, keep_readonly: bool = False, **kwargs: Any) -> JSON: :rtype: dict """ serializer = Serializer(self._infer_class_models()) - return serializer._serialize(self, keep_readonly=keep_readonly, **kwargs) # type: ignore + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, keep_readonly=keep_readonly, **kwargs + ) def as_dict( self, @@ -380,12 +426,15 @@ def my_key_transformer(key, attr_desc, value): If you want XML serialization, you can pass the kwargs is_xml=True. + :param bool keep_readonly: If you want to serialize the readonly attributes :param function key_transformer: A key transformer function. :returns: A dict JSON compatible object :rtype: dict """ serializer = Serializer(self._infer_class_models()) - return serializer._serialize(self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs) # type: ignore + return serializer._serialize( # type: ignore # pylint: disable=protected-access + self, key_transformer=key_transformer, keep_readonly=keep_readonly, **kwargs + ) @classmethod def _infer_class_models(cls): @@ -395,7 +444,7 @@ def _infer_class_models(cls): client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)} if cls.__name__ not in client_models: raise ValueError("Not Autorest generated code") - except Exception: + except Exception: # pylint: disable=broad-exception-caught # Assume it's not Autorest generated (tests?). Add ourselves as dependencies. 
client_models = {cls.__name__: cls} return client_models @@ -408,6 +457,7 @@ def deserialize(cls: Type[ModelType], data: Any, content_type: Optional[str] = N :param str content_type: JSON by default, set application/xml if XML. :returns: An instance of this model :raises: DeserializationError if something went wrong + :rtype: ModelType """ deserializer = Deserializer(cls._infer_class_models()) return deserializer(cls.__name__, data, content_type=content_type) # type: ignore @@ -426,9 +476,11 @@ def from_dict( and last_rest_key_case_insensitive_extractor) :param dict data: A dict using RestAPI structure + :param function key_extractors: A key extractor function. :param str content_type: JSON by default, set application/xml if XML. :returns: An instance of this model :raises: DeserializationError if something went wrong + :rtype: ModelType """ deserializer = Deserializer(cls._infer_class_models()) deserializer.key_extractors = ( # type: ignore @@ -448,21 +500,25 @@ def _flatten_subtype(cls, key, objects): return {} result = dict(cls._subtype_map[key]) for valuetype in cls._subtype_map[key].values(): - result.update(objects[valuetype]._flatten_subtype(key, objects)) + result.update(objects[valuetype]._flatten_subtype(key, objects)) # pylint: disable=protected-access return result @classmethod def _classify(cls, response, objects): """Check the class _subtype_map for any child classes. We want to ignore any inherited _subtype_maps. - Remove the polymorphic key from the initial data. + + :param dict response: The initial data + :param dict objects: The class objects + :returns: The class to be used + :rtype: class """ for subtype_key in cls.__dict__.get("_subtype_map", {}).keys(): subtype_value = None if not isinstance(response, ET.Element): rest_api_response_key = cls._get_rest_key_parts(subtype_key)[-1] - subtype_value = response.pop(rest_api_response_key, None) or response.pop(subtype_key, None) + subtype_value = response.get(rest_api_response_key, None) or response.get(subtype_key, None) else: subtype_value = xml_key_extractor(subtype_key, cls._attribute_map[subtype_key], response) if subtype_value: @@ -501,11 +557,13 @@ def _decode_attribute_map_key(key): inside the received data. :param str key: A key string from the generated code + :returns: The decoded key + :rtype: str """ return key.replace("\\.", ".") -class Serializer(object): +class Serializer: # pylint: disable=too-many-public-methods """Request object model serializer.""" basic_types = {str: "str", int: "int", bool: "bool", float: "float"} @@ -540,7 +598,7 @@ class Serializer(object): "multiple": lambda x, y: x % y != 0, } - def __init__(self, classes: Optional[Mapping[str, type]] = None): + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: self.serialize_type = { "iso-8601": Serializer.serialize_iso, "rfc-1123": Serializer.serialize_rfc, @@ -560,13 +618,16 @@ def __init__(self, classes: Optional[Mapping[str, type]] = None): self.key_transformer = full_restapi_key_transformer self.client_side_validation = True - def _serialize(self, target_obj, data_type=None, **kwargs): + def _serialize( # pylint: disable=too-many-nested-blocks, too-many-branches, too-many-statements, too-many-locals + self, target_obj, data_type=None, **kwargs + ): """Serialize data into a string according to type. - :param target_obj: The data to be serialized. + :param object target_obj: The data to be serialized. :param str data_type: The type to be serialized from. 
:rtype: str, dict :raises: SerializationError if serialization fails. + :returns: The serialized data. """ key_transformer = kwargs.get("key_transformer", self.key_transformer) keep_readonly = kwargs.get("keep_readonly", False) @@ -592,12 +653,14 @@ def _serialize(self, target_obj, data_type=None, **kwargs): serialized = {} if is_xml_model_serialization: - serialized = target_obj._create_xml_node() + serialized = target_obj._create_xml_node() # pylint: disable=protected-access try: - attributes = target_obj._attribute_map + attributes = target_obj._attribute_map # pylint: disable=protected-access for attr, attr_desc in attributes.items(): attr_name = attr - if not keep_readonly and target_obj._validation.get(attr_name, {}).get("readonly", False): + if not keep_readonly and target_obj._validation.get( # pylint: disable=protected-access + attr_name, {} + ).get("readonly", False): continue if attr_name == "additional_properties" and attr_desc["key"] == "": @@ -633,7 +696,8 @@ def _serialize(self, target_obj, data_type=None, **kwargs): if isinstance(new_attr, list): serialized.extend(new_attr) # type: ignore elif isinstance(new_attr, ET.Element): - # If the down XML has no XML/Name, we MUST replace the tag with the local tag. But keeping the namespaces. + # If the down XML has no XML/Name, + # we MUST replace the tag with the local tag. But keeping the namespaces. if "name" not in getattr(orig_attr, "_xml_map", {}): splitted_tag = new_attr.tag.split("}") if len(splitted_tag) == 2: # Namespace @@ -664,17 +728,17 @@ def _serialize(self, target_obj, data_type=None, **kwargs): except (AttributeError, KeyError, TypeError) as err: msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj)) raise SerializationError(msg) from err - else: - return serialized + return serialized def body(self, data, data_type, **kwargs): """Serialize data intended for a request body. - :param data: The data to be serialized. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: dict :raises: SerializationError if serialization fails. :raises: ValueError if data is None + :returns: The serialized request body """ # Just in case this is a dict @@ -703,7 +767,7 @@ def body(self, data, data_type, **kwargs): attribute_key_case_insensitive_extractor, last_rest_key_case_insensitive_extractor, ] - data = deserializer._deserialize(data_type, data) + data = deserializer._deserialize(data_type, data) # pylint: disable=protected-access except DeserializationError as err: raise SerializationError("Unable to build a model: " + str(err)) from err @@ -712,9 +776,11 @@ def body(self, data, data_type, **kwargs): def url(self, name, data, data_type, **kwargs): """Serialize data intended for a URL path. - :param data: The data to be serialized. + :param str name: The name of the URL path parameter. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: str + :returns: The serialized URL path :raises: TypeError if serialization fails. 
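The `url`, `query`, and `header` helpers changed here now chain the raised `TypeError` to the underlying `SerializationError` via `raise ... from exc`. A small sketch against the private `_serialization` module (an internal API, shown only to make the behavior concrete):

from azure.ai.inference._serialization import Serializer

serializer = Serializer()
print(serializer.query("api-version", "2024-05-01-preview", "str"))  # '2024-05-01-preview'
print(serializer.header("x-ms-retry-count", 3, "int"))               # '3'

try:
    serializer.header("x-ms-retry-count", object(), "int")
except TypeError as err:
    # err.__cause__ is the original SerializationError, thanks to "raise ... from exc".
    print(type(err.__cause__).__name__)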
:raises: ValueError if data is None """ @@ -728,21 +794,20 @@ def url(self, name, data, data_type, **kwargs): output = output.replace("{", quote("{")).replace("}", quote("}")) else: output = quote(str(output), safe="") - except SerializationError: - raise TypeError("{} must be type {}.".format(name, data_type)) - else: - return output + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return output def query(self, name, data, data_type, **kwargs): """Serialize data intended for a URL query. - :param data: The data to be serialized. + :param str name: The name of the query parameter. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. - :keyword bool skip_quote: Whether to skip quote the serialized result. - Defaults to False. :rtype: str, list :raises: TypeError if serialization fails. :raises: ValueError if data is None + :returns: The serialized query parameter """ try: # Treat the list aside, since we don't want to encode the div separator @@ -759,19 +824,20 @@ def query(self, name, data, data_type, **kwargs): output = str(output) else: output = quote(str(output), safe="") - except SerializationError: - raise TypeError("{} must be type {}.".format(name, data_type)) - else: - return str(output) + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) def header(self, name, data, data_type, **kwargs): """Serialize data intended for a request header. - :param data: The data to be serialized. + :param str name: The name of the header. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. :rtype: str :raises: TypeError if serialization fails. :raises: ValueError if data is None + :returns: The serialized header """ try: if data_type in ["[str]"]: @@ -780,21 +846,20 @@ def header(self, name, data, data_type, **kwargs): output = self.serialize_data(data, data_type, **kwargs) if data_type == "bool": output = json.dumps(output) - except SerializationError: - raise TypeError("{} must be type {}.".format(name, data_type)) - else: - return str(output) + except SerializationError as exc: + raise TypeError("{} must be type {}.".format(name, data_type)) from exc + return str(output) def serialize_data(self, data, data_type, **kwargs): """Serialize generic data according to supplied data type. - :param data: The data to be serialized. + :param object data: The data to be serialized. :param str data_type: The type to be serialized from. - :param bool required: Whether it's essential that the data not be - empty or None :raises: AttributeError if required data is None. :raises: ValueError if data is None :raises: SerializationError if serialization fails. + :returns: The serialized data. + :rtype: str, int, float, bool, dict, list """ if data is None: raise ValueError("No value for given attribute") @@ -805,7 +870,7 @@ def serialize_data(self, data, data_type, **kwargs): if data_type in self.basic_types.values(): return self.serialize_basic(data, data_type, **kwargs) - elif data_type in self.serialize_type: + if data_type in self.serialize_type: return self.serialize_type[data_type](data, **kwargs) # If dependencies is empty, try with current data class @@ -821,11 +886,10 @@ def serialize_data(self, data, data_type, **kwargs): except (ValueError, TypeError) as err: msg = "Unable to serialize value: {!r} as type: {!r}." 
raise SerializationError(msg.format(data, data_type)) from err - else: - return self._serialize(data, **kwargs) + return self._serialize(data, **kwargs) @classmethod - def _get_custom_serializers(cls, data_type, **kwargs): + def _get_custom_serializers(cls, data_type, **kwargs): # pylint: disable=inconsistent-return-statements custom_serializer = kwargs.get("basic_types_serializers", {}).get(data_type) if custom_serializer: return custom_serializer @@ -841,23 +905,26 @@ def serialize_basic(cls, data, data_type, **kwargs): - basic_types_serializers dict[str, callable] : If set, use the callable as serializer - is_xml bool : If set, use xml_basic_types_serializers - :param data: Object to be serialized. + :param obj data: Object to be serialized. :param str data_type: Type of object in the iterable. + :rtype: str, int, float, bool + :return: serialized object """ custom_serializer = cls._get_custom_serializers(data_type, **kwargs) if custom_serializer: return custom_serializer(data) if data_type == "str": return cls.serialize_unicode(data) - return eval(data_type)(data) # nosec + return eval(data_type)(data) # nosec # pylint: disable=eval-used @classmethod def serialize_unicode(cls, data): """Special handling for serializing unicode strings in Py2. Encode to UTF-8 if unicode, otherwise handle as a str. - :param data: Object to be serialized. + :param str data: Object to be serialized. :rtype: str + :return: serialized object """ try: # If I received an enum, return its value return data.value @@ -871,8 +938,7 @@ def serialize_unicode(cls, data): return data except NameError: return str(data) - else: - return str(data) + return str(data) def serialize_iter(self, data, iter_type, div=None, **kwargs): """Serialize iterable. @@ -882,15 +948,13 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs): serialization_ctxt['type'] should be same as data_type. - is_xml bool : If set, serialize as XML - :param list attr: Object to be serialized. + :param list data: Object to be serialized. :param str iter_type: Type of object in the iterable. - :param bool required: Whether the objects in the iterable must - not be None or empty. :param str div: If set, this str will be used to combine the elements in the iterable into a combined string. Default is 'None'. - :keyword bool do_quote: Whether to quote the serialized result of each iterable element. Defaults to False. :rtype: list, str + :return: serialized iterable """ if isinstance(data, str): raise SerializationError("Refuse str type as a valid iter type.") @@ -945,9 +1009,8 @@ def serialize_dict(self, attr, dict_type, **kwargs): :param dict attr: Object to be serialized. :param str dict_type: Type of object in the dictionary. - :param bool required: Whether the objects in the dictionary must - not be None or empty. :rtype: dict + :return: serialized dictionary """ serialization_ctxt = kwargs.get("serialization_ctxt", {}) serialized = {} @@ -971,7 +1034,7 @@ def serialize_dict(self, attr, dict_type, **kwargs): return serialized - def serialize_object(self, attr, **kwargs): + def serialize_object(self, attr, **kwargs): # pylint: disable=too-many-return-statements """Serialize a generic object. This will be handled as a dictionary. If object passed in is not a basic type (str, int, float, dict, list) it will simply be @@ -979,6 +1042,7 @@ def serialize_object(self, attr, **kwargs): :param dict attr: Object to be serialized. 
:rtype: dict or str + :return: serialized object """ if attr is None: return None @@ -1003,7 +1067,7 @@ def serialize_object(self, attr, **kwargs): return self.serialize_decimal(attr) # If it's a model or I know this dependency, serialize as a Model - elif obj_type in self.dependencies.values() or isinstance(attr, Model): + if obj_type in self.dependencies.values() or isinstance(attr, Model): return self._serialize(attr) if obj_type == dict: @@ -1034,56 +1098,61 @@ def serialize_enum(attr, enum_obj=None): try: enum_obj(result) # type: ignore return result - except ValueError: + except ValueError as exc: for enum_value in enum_obj: # type: ignore if enum_value.value.lower() == str(attr).lower(): return enum_value.value error = "{!r} is not valid value for enum {!r}" - raise SerializationError(error.format(attr, enum_obj)) + raise SerializationError(error.format(attr, enum_obj)) from exc @staticmethod - def serialize_bytearray(attr, **kwargs): + def serialize_bytearray(attr, **kwargs): # pylint: disable=unused-argument """Serialize bytearray into base-64 string. - :param attr: Object to be serialized. + :param str attr: Object to be serialized. :rtype: str + :return: serialized base64 """ return b64encode(attr).decode() @staticmethod - def serialize_base64(attr, **kwargs): + def serialize_base64(attr, **kwargs): # pylint: disable=unused-argument """Serialize str into base-64 string. - :param attr: Object to be serialized. + :param str attr: Object to be serialized. :rtype: str + :return: serialized base64 """ encoded = b64encode(attr).decode("ascii") return encoded.strip("=").replace("+", "-").replace("/", "_") @staticmethod - def serialize_decimal(attr, **kwargs): + def serialize_decimal(attr, **kwargs): # pylint: disable=unused-argument """Serialize Decimal object to float. - :param attr: Object to be serialized. + :param decimal attr: Object to be serialized. :rtype: float + :return: serialized decimal """ return float(attr) @staticmethod - def serialize_long(attr, **kwargs): + def serialize_long(attr, **kwargs): # pylint: disable=unused-argument """Serialize long (Py2) or int (Py3). - :param attr: Object to be serialized. + :param int attr: Object to be serialized. :rtype: int/long + :return: serialized long """ return _long_type(attr) @staticmethod - def serialize_date(attr, **kwargs): + def serialize_date(attr, **kwargs): # pylint: disable=unused-argument """Serialize Date object into ISO-8601 formatted string. :param Date attr: Object to be serialized. :rtype: str + :return: serialized date """ if isinstance(attr, str): attr = isodate.parse_date(attr) @@ -1091,11 +1160,12 @@ def serialize_date(attr, **kwargs): return t @staticmethod - def serialize_time(attr, **kwargs): + def serialize_time(attr, **kwargs): # pylint: disable=unused-argument """Serialize Time object into ISO-8601 formatted string. :param datetime.time attr: Object to be serialized. :rtype: str + :return: serialized time """ if isinstance(attr, str): attr = isodate.parse_time(attr) @@ -1105,30 +1175,32 @@ def serialize_time(attr, **kwargs): return t @staticmethod - def serialize_duration(attr, **kwargs): + def serialize_duration(attr, **kwargs): # pylint: disable=unused-argument """Serialize TimeDelta object into ISO-8601 formatted string. :param TimeDelta attr: Object to be serialized. 
:rtype: str + :return: serialized duration """ if isinstance(attr, str): attr = isodate.parse_duration(attr) return isodate.duration_isoformat(attr) @staticmethod - def serialize_rfc(attr, **kwargs): + def serialize_rfc(attr, **kwargs): # pylint: disable=unused-argument """Serialize Datetime object into RFC-1123 formatted string. :param Datetime attr: Object to be serialized. :rtype: str :raises: TypeError if format invalid. + :return: serialized rfc """ try: if not attr.tzinfo: _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") utc = attr.utctimetuple() - except AttributeError: - raise TypeError("RFC1123 object must be valid Datetime object.") + except AttributeError as exc: + raise TypeError("RFC1123 object must be valid Datetime object.") from exc return "{}, {:02} {} {:04} {:02}:{:02}:{:02} GMT".format( Serializer.days[utc.tm_wday], @@ -1141,12 +1213,13 @@ def serialize_rfc(attr, **kwargs): ) @staticmethod - def serialize_iso(attr, **kwargs): + def serialize_iso(attr, **kwargs): # pylint: disable=unused-argument """Serialize Datetime object into ISO-8601 formatted string. :param Datetime attr: Object to be serialized. :rtype: str :raises: SerializationError if format invalid. + :return: serialized iso """ if isinstance(attr, str): attr = isodate.parse_datetime(attr) @@ -1172,13 +1245,14 @@ def serialize_iso(attr, **kwargs): raise TypeError(msg) from err @staticmethod - def serialize_unix(attr, **kwargs): + def serialize_unix(attr, **kwargs): # pylint: disable=unused-argument """Serialize Datetime object into IntTime format. This is represented as seconds. :param Datetime attr: Object to be serialized. :rtype: int :raises: SerializationError if format invalid + :return: serialied unix """ if isinstance(attr, int): return attr @@ -1186,11 +1260,11 @@ def serialize_unix(attr, **kwargs): if not attr.tzinfo: _LOGGER.warning("Datetime with no tzinfo will be considered UTC.") return int(calendar.timegm(attr.utctimetuple())) - except AttributeError: - raise TypeError("Unix time object must be valid Datetime object.") + except AttributeError as exc: + raise TypeError("Unix time object must be valid Datetime object.") from exc -def rest_key_extractor(attr, attr_desc, data): +def rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument key = attr_desc["key"] working_data = data @@ -1211,7 +1285,9 @@ def rest_key_extractor(attr, attr_desc, data): return working_data.get(key) -def rest_key_case_insensitive_extractor(attr, attr_desc, data): +def rest_key_case_insensitive_extractor( # pylint: disable=unused-argument, inconsistent-return-statements + attr, attr_desc, data +): key = attr_desc["key"] working_data = data @@ -1232,17 +1308,29 @@ def rest_key_case_insensitive_extractor(attr, attr_desc, data): return attribute_key_case_insensitive_extractor(key, None, working_data) -def last_rest_key_extractor(attr, attr_desc, data): - """Extract the attribute in "data" based on the last part of the JSON path key.""" +def last_rest_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument + """Extract the attribute in "data" based on the last part of the JSON path key. 
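For the date/time serializers touched above, a short illustration of the static helpers (again using the private `_serialization` module purely for demonstration; the commented outputs assume the formatting shown in the diffed code):

import datetime
from azure.ai.inference._serialization import Serializer

dt = datetime.datetime(2024, 5, 1, 12, 30, 0, tzinfo=datetime.timezone.utc)
print(Serializer.serialize_iso(dt))    # 2024-05-01T12:30:00.000Z  (ISO-8601)
print(Serializer.serialize_rfc(dt))    # Wed, 01 May 2024 12:30:00 GMT  (RFC-1123)
print(Serializer.serialize_unix(dt))   # 1714566600  (seconds since the epoch)
print(Serializer.serialize_duration(datetime.timedelta(hours=1, minutes=30)))  # ISO-8601 duration, e.g. PT1H30M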
+ + :param str attr: The attribute to extract + :param dict attr_desc: The attribute description + :param dict data: The data to extract from + :rtype: object + :returns: The extracted attribute + """ key = attr_desc["key"] dict_keys = _FLATTEN.split(key) return attribute_key_extractor(dict_keys[-1], None, data) -def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): +def last_rest_key_case_insensitive_extractor(attr, attr_desc, data): # pylint: disable=unused-argument """Extract the attribute in "data" based on the last part of the JSON path key. This is the case insensitive version of "last_rest_key_extractor" + :param str attr: The attribute to extract + :param dict attr_desc: The attribute description + :param dict data: The data to extract from + :rtype: object + :returns: The extracted attribute """ key = attr_desc["key"] dict_keys = _FLATTEN.split(key) @@ -1279,7 +1367,7 @@ def _extract_name_from_internal_type(internal_type): return xml_name -def xml_key_extractor(attr, attr_desc, data): +def xml_key_extractor(attr, attr_desc, data): # pylint: disable=unused-argument,too-many-return-statements if isinstance(data, dict): return None @@ -1331,22 +1419,21 @@ def xml_key_extractor(attr, attr_desc, data): if is_iter_type: if is_wrapped: return None # is_wrapped no node, we want None - else: - return [] # not wrapped, assume empty list + return [] # not wrapped, assume empty list return None # Assume it's not there, maybe an optional node. # If is_iter_type and not wrapped, return all found children if is_iter_type: if not is_wrapped: return children - else: # Iter and wrapped, should have found one node only (the wrap one) - if len(children) != 1: - raise DeserializationError( - "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( - xml_name - ) + # Iter and wrapped, should have found one node only (the wrap one) + if len(children) != 1: + raise DeserializationError( + "Tried to deserialize an array not wrapped, and found several nodes '{}'. Maybe you should declare this array as wrapped?".format( # pylint: disable=line-too-long + xml_name ) - return list(children[0]) # Might be empty list and that's ok. + ) + return list(children[0]) # Might be empty list and that's ok. # Here it's not a itertype, we should have found one element only or empty if len(children) > 1: @@ -1354,7 +1441,7 @@ def xml_key_extractor(attr, attr_desc, data): return children[0] -class Deserializer(object): +class Deserializer: """Response object model deserializer. :param dict classes: Class type dictionary for deserializing complex types. @@ -1363,9 +1450,9 @@ class Deserializer(object): basic_types = {str: "str", int: "int", bool: "bool", float: "float"} - valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}" r"\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") + valid_date = re.compile(r"\d{4}[-]\d{2}[-]\d{2}T\d{2}:\d{2}:\d{2}\.?\d*Z?[-+]?[\d{2}]?:?[\d{2}]?") - def __init__(self, classes: Optional[Mapping[str, type]] = None): + def __init__(self, classes: Optional[Mapping[str, type]] = None) -> None: self.deserialize_type = { "iso-8601": Deserializer.deserialize_iso, "rfc-1123": Deserializer.deserialize_rfc, @@ -1403,11 +1490,12 @@ def __call__(self, target_obj, response_data, content_type=None): :param str content_type: Swagger "produces" if available. :raises: DeserializationError if deserialization fails. :return: Deserialized object. 
+ :rtype: object """ data = self._unpack_content(response_data, content_type) return self._deserialize(target_obj, data) - def _deserialize(self, target_obj, data): + def _deserialize(self, target_obj, data): # pylint: disable=inconsistent-return-statements """Call the deserializer on a model. Data needs to be already deserialized as JSON or XML ElementTree @@ -1416,12 +1504,13 @@ def _deserialize(self, target_obj, data): :param object data: Object to deserialize. :raises: DeserializationError if deserialization fails. :return: Deserialized object. + :rtype: object """ # This is already a model, go recursive just in case if hasattr(data, "_attribute_map"): constants = [name for name, config in getattr(data, "_validation", {}).items() if config.get("constant")] try: - for attr, mapconfig in data._attribute_map.items(): + for attr, mapconfig in data._attribute_map.items(): # pylint: disable=protected-access if attr in constants: continue value = getattr(data, attr) @@ -1440,13 +1529,13 @@ def _deserialize(self, target_obj, data): if isinstance(response, str): return self.deserialize_data(data, response) - elif isinstance(response, type) and issubclass(response, Enum): + if isinstance(response, type) and issubclass(response, Enum): return self.deserialize_enum(data, response) if data is None or data is CoreNull: return data try: - attributes = response._attribute_map # type: ignore + attributes = response._attribute_map # type: ignore # pylint: disable=protected-access d_attrs = {} for attr, attr_desc in attributes.items(): # Check empty string. If it's not empty, someone has a real "additionalProperties"... @@ -1476,9 +1565,8 @@ def _deserialize(self, target_obj, data): except (AttributeError, TypeError, KeyError) as err: msg = "Unable to deserialize to object: " + class_name # type: ignore raise DeserializationError(msg) from err - else: - additional_properties = self._build_additional_properties(attributes, data) - return self._instantiate_model(response, d_attrs, additional_properties) + additional_properties = self._build_additional_properties(attributes, data) + return self._instantiate_model(response, d_attrs, additional_properties) def _build_additional_properties(self, attribute_map, data): if not self.additional_properties_detection: @@ -1505,6 +1593,8 @@ def _classify_target(self, target, data): :param str target: The target object type to deserialize to. :param str/dict data: The response data to deserialize. + :return: The classified target object and its class name. + :rtype: tuple """ if target is None: return None, None @@ -1516,7 +1606,7 @@ def _classify_target(self, target, data): return target, target try: - target = target._classify(data, self.dependencies) # type: ignore + target = target._classify(data, self.dependencies) # type: ignore # pylint: disable=protected-access except AttributeError: pass # Target is not a Model, no classify return target, target.__class__.__name__ # type: ignore @@ -1531,10 +1621,12 @@ def failsafe_deserialize(self, target_obj, data, content_type=None): :param str target_obj: The target object type to deserialize to. :param str/dict data: The response data to deserialize. :param str content_type: Swagger "produces" if available. + :return: Deserialized object. + :rtype: object """ try: return self(target_obj, data, content_type=content_type) - except: + except: # pylint: disable=bare-except _LOGGER.debug( "Ran into a deserialization error. 
Ignoring since this is failsafe deserialization", exc_info=True ) @@ -1552,10 +1644,12 @@ def _unpack_content(raw_data, content_type=None): If raw_data is something else, bypass all logic and return it directly. - :param raw_data: Data to be processed. - :param content_type: How to parse if raw_data is a string/bytes. + :param obj raw_data: Data to be processed. + :param str content_type: How to parse if raw_data is a string/bytes. :raises JSONDecodeError: If JSON is requested and parsing is impossible. :raises UnicodeDecodeError: If bytes is not UTF8 + :rtype: object + :return: Unpacked content. """ # Assume this is enough to detect a Pipeline Response without importing it context = getattr(raw_data, "context", {}) @@ -1579,24 +1673,35 @@ def _unpack_content(raw_data, content_type=None): def _instantiate_model(self, response, attrs, additional_properties=None): """Instantiate a response model passing in deserialized args. - :param response: The response model class. - :param d_attrs: The deserialized response attributes. + :param Response response: The response model class. + :param dict attrs: The deserialized response attributes. + :param dict additional_properties: Additional properties to be set. + :rtype: Response + :return: The instantiated response model. """ if callable(response): subtype = getattr(response, "_subtype_map", {}) try: - readonly = [k for k, v in response._validation.items() if v.get("readonly")] - const = [k for k, v in response._validation.items() if v.get("constant")] + readonly = [ + k + for k, v in response._validation.items() # pylint: disable=protected-access # type: ignore + if v.get("readonly") + ] + const = [ + k + for k, v in response._validation.items() # pylint: disable=protected-access # type: ignore + if v.get("constant") + ] kwargs = {k: v for k, v in attrs.items() if k not in subtype and k not in readonly + const} response_obj = response(**kwargs) for attr in readonly: setattr(response_obj, attr, attrs.get(attr)) if additional_properties: - response_obj.additional_properties = additional_properties + response_obj.additional_properties = additional_properties # type: ignore return response_obj except TypeError as err: msg = "Unable to deserialize {} into model {}. ".format(kwargs, response) # type: ignore - raise DeserializationError(msg + str(err)) + raise DeserializationError(msg + str(err)) from err else: try: for attr, value in attrs.items(): @@ -1605,15 +1710,16 @@ def _instantiate_model(self, response, attrs, additional_properties=None): except Exception as exp: msg = "Unable to populate response model. " msg += "Type: {}, Error: {}".format(type(response), exp) - raise DeserializationError(msg) + raise DeserializationError(msg) from exp - def deserialize_data(self, data, data_type): + def deserialize_data(self, data, data_type): # pylint: disable=too-many-return-statements """Process data for deserialization according to data type. :param str data: The response string to be deserialized. :param str data_type: The type to deserialize to. :raises: DeserializationError if deserialization fails. :return: Deserialized object. 
+ :rtype: object """ if data is None: return data @@ -1627,7 +1733,11 @@ def deserialize_data(self, data, data_type): if isinstance(data, self.deserialize_expected_types.get(data_type, tuple())): return data - is_a_text_parsing_type = lambda x: x not in ["object", "[]", r"{}"] + is_a_text_parsing_type = lambda x: x not in [ # pylint: disable=unnecessary-lambda-assignment + "object", + "[]", + r"{}", + ] if isinstance(data, ET.Element) and is_a_text_parsing_type(data_type) and not data.text: return None data_val = self.deserialize_type[data_type](data) @@ -1647,14 +1757,14 @@ def deserialize_data(self, data, data_type): msg = "Unable to deserialize response data." msg += " Data: {}, {}".format(data, data_type) raise DeserializationError(msg) from err - else: - return self._deserialize(obj_type, data) + return self._deserialize(obj_type, data) def deserialize_iter(self, attr, iter_type): """Deserialize an iterable. :param list attr: Iterable to be deserialized. :param str iter_type: The type of object in the iterable. + :return: Deserialized iterable. :rtype: list """ if attr is None: @@ -1671,6 +1781,7 @@ def deserialize_dict(self, attr, dict_type): :param dict/list attr: Dictionary to be deserialized. Also accepts a list of key, value pairs. :param str dict_type: The object type of the items in the dictionary. + :return: Deserialized dictionary. :rtype: dict """ if isinstance(attr, list): @@ -1681,11 +1792,12 @@ def deserialize_dict(self, attr, dict_type): attr = {el.tag: el.text for el in attr} return {k: self.deserialize_data(v, dict_type) for k, v in attr.items()} - def deserialize_object(self, attr, **kwargs): + def deserialize_object(self, attr, **kwargs): # pylint: disable=too-many-return-statements """Deserialize a generic object. This will be handled as a dictionary. :param dict attr: Dictionary to be deserialized. + :return: Deserialized object. :rtype: dict :raises: TypeError if non-builtin datatype encountered. """ @@ -1720,11 +1832,10 @@ def deserialize_object(self, attr, **kwargs): pass return deserialized - else: - error = "Cannot deserialize generic object with type: " - raise TypeError(error + str(obj_type)) + error = "Cannot deserialize generic object with type: " + raise TypeError(error + str(obj_type)) - def deserialize_basic(self, attr, data_type): + def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return-statements """Deserialize basic builtin data type from string. Will attempt to convert to str, int, float and bool. This function will also accept '1', '0', 'true' and 'false' as @@ -1732,6 +1843,7 @@ def deserialize_basic(self, attr, data_type): :param str attr: response string to be deserialized. :param str data_type: deserialization data type. + :return: Deserialized basic type. :rtype: str, int, float or bool :raises: TypeError if string format is not valid. """ @@ -1743,24 +1855,23 @@ def deserialize_basic(self, attr, data_type): if data_type == "str": # None or '', node is empty string. return "" - else: - # None or '', node with a strong type is None. - # Don't try to model "empty bool" or "empty int" - return None + # None or '', node with a strong type is None. 
+ # Don't try to model "empty bool" or "empty int" + return None if data_type == "bool": if attr in [True, False, 1, 0]: return bool(attr) - elif isinstance(attr, str): + if isinstance(attr, str): if attr.lower() in ["true", "1"]: return True - elif attr.lower() in ["false", "0"]: + if attr.lower() in ["false", "0"]: return False raise TypeError("Invalid boolean value: {}".format(attr)) if data_type == "str": return self.deserialize_unicode(attr) - return eval(data_type)(attr) # nosec + return eval(data_type)(attr) # nosec # pylint: disable=eval-used @staticmethod def deserialize_unicode(data): @@ -1768,6 +1879,7 @@ def deserialize_unicode(data): as a string. :param str data: response string to be deserialized. + :return: Deserialized string. :rtype: str or unicode """ # We might be here because we have an enum modeled as string, @@ -1781,8 +1893,7 @@ def deserialize_unicode(data): return data except NameError: return str(data) - else: - return str(data) + return str(data) @staticmethod def deserialize_enum(data, enum_obj): @@ -1794,6 +1905,7 @@ def deserialize_enum(data, enum_obj): :param str data: Response string to be deserialized. If this value is None or invalid it will be returned as-is. :param Enum enum_obj: Enum object to deserialize to. + :return: Deserialized enum object. :rtype: Enum """ if isinstance(data, enum_obj) or data is None: @@ -1804,9 +1916,9 @@ def deserialize_enum(data, enum_obj): # Workaround. We might consider remove it in the future. try: return list(enum_obj.__members__.values())[data] - except IndexError: + except IndexError as exc: error = "{!r} is not a valid index for enum {!r}" - raise DeserializationError(error.format(data, enum_obj)) + raise DeserializationError(error.format(data, enum_obj)) from exc try: return enum_obj(str(data)) except ValueError: @@ -1822,6 +1934,7 @@ def deserialize_bytearray(attr): """Deserialize string into bytearray. :param str attr: response string to be deserialized. + :return: Deserialized bytearray :rtype: bytearray :raises: TypeError if string format invalid. """ @@ -1834,6 +1947,7 @@ def deserialize_base64(attr): """Deserialize base64 encoded string into string. :param str attr: response string to be deserialized. + :return: Deserialized base64 string :rtype: bytearray :raises: TypeError if string format invalid. """ @@ -1849,8 +1963,9 @@ def deserialize_decimal(attr): """Deserialize string into Decimal object. :param str attr: response string to be deserialized. - :rtype: Decimal + :return: Deserialized decimal :raises: DeserializationError if string format invalid. + :rtype: decimal """ if isinstance(attr, ET.Element): attr = attr.text @@ -1865,6 +1980,7 @@ def deserialize_long(attr): """Deserialize string into long (Py2) or int (Py3). :param str attr: response string to be deserialized. + :return: Deserialized int :rtype: long or int :raises: ValueError if string format invalid. """ @@ -1877,6 +1993,7 @@ def deserialize_duration(attr): """Deserialize ISO-8601 formatted string into TimeDelta object. :param str attr: response string to be deserialized. + :return: Deserialized duration :rtype: TimeDelta :raises: DeserializationError if string format invalid. """ @@ -1887,14 +2004,14 @@ def deserialize_duration(attr): except (ValueError, OverflowError, AttributeError) as err: msg = "Cannot deserialize duration object." raise DeserializationError(msg) from err - else: - return duration + return duration @staticmethod def deserialize_date(attr): """Deserialize ISO-8601 formatted string into Date object. 
:param str attr: response string to be deserialized. + :return: Deserialized date :rtype: Date :raises: DeserializationError if string format invalid. """ @@ -1910,6 +2027,7 @@ def deserialize_time(attr): """Deserialize ISO-8601 formatted string into time object. :param str attr: response string to be deserialized. + :return: Deserialized time :rtype: datetime.time :raises: DeserializationError if string format invalid. """ @@ -1924,6 +2042,7 @@ def deserialize_rfc(attr): """Deserialize RFC-1123 formatted string into Datetime object. :param str attr: response string to be deserialized. + :return: Deserialized RFC datetime :rtype: Datetime :raises: DeserializationError if string format invalid. """ @@ -1939,14 +2058,14 @@ def deserialize_rfc(attr): except ValueError as err: msg = "Cannot deserialize to rfc datetime object." raise DeserializationError(msg) from err - else: - return date_obj + return date_obj @staticmethod def deserialize_iso(attr): """Deserialize ISO-8601 formatted string into Datetime object. :param str attr: response string to be deserialized. + :return: Deserialized ISO datetime :rtype: Datetime :raises: DeserializationError if string format invalid. """ @@ -1976,8 +2095,7 @@ def deserialize_iso(attr): except (ValueError, OverflowError, AttributeError) as err: msg = "Cannot deserialize datetime object." raise DeserializationError(msg) from err - else: - return date_obj + return date_obj @staticmethod def deserialize_unix(attr): @@ -1985,6 +2103,7 @@ def deserialize_unix(attr): This is represented as seconds. :param int attr: Object to be serialized. + :return: Deserialized datetime :rtype: Datetime :raises: DeserializationError if format invalid """ @@ -1996,5 +2115,4 @@ def deserialize_unix(attr): except ValueError as err: msg = "Cannot deserialize to unix datetime object." raise DeserializationError(msg) from err - else: - return date_obj + return date_obj diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py index 8ea240fb008b..147e96be133e 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_vendor.py @@ -15,7 +15,6 @@ ) if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core import PipelineClient from ._serialization import Deserializer, Serializer diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py b/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py index 84058978c521..be71c81bd282 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b7" +VERSION = "1.0.0b1" diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py index c31764c00803..668f989a5838 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/__init__.py @@ -5,21 +5,29 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._patch import ChatCompletionsClient -from ._patch import EmbeddingsClient -from ._patch import ImageEmbeddingsClient +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import -from ._patch import load_client +from ._client import ChatCompletionsClient # type: ignore +from ._client import EmbeddingsClient # type: ignore +from ._client import ImageEmbeddingsClient # type: ignore + +try: + from ._patch import __all__ as _patch_all + from ._patch import * +except ImportError: + _patch_all = [] from ._patch import patch_sdk as _patch_sdk __all__ = [ - "load_client", "ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient", ] - +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py index 30c7afbfbd91..7cea61120519 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_client.py @@ -28,11 +28,10 @@ ) if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials_async import AsyncTokenCredential -class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ChatCompletionsClient(ChatCompletionsClientOperationsMixin): """ChatCompletionsClient. :param endpoint: Service host. Required. @@ -115,7 +114,7 @@ async def __aexit__(self, *exc_details: Any) -> None: await self._client.__aexit__(*exc_details) -class EmbeddingsClient(EmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class EmbeddingsClient(EmbeddingsClientOperationsMixin): """EmbeddingsClient. :param endpoint: Service host. Required. @@ -198,7 +197,7 @@ async def __aexit__(self, *exc_details: Any) -> None: await self._client.__aexit__(*exc_details) -class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): # pylint: disable=client-accepts-api-version-keyword +class ImageEmbeddingsClient(ImageEmbeddingsClientOperationsMixin): """ImageEmbeddingsClient. :param endpoint: Service host. Required. diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py index e4c5d7111d22..2eee5cfe60cb 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_configuration.py @@ -14,11 +14,10 @@ from .._version import VERSION if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core.credentials_async import AsyncTokenCredential -class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ChatCompletionsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ChatCompletionsClient. Note that all parameters used to create this instance are saved as instance @@ -78,7 +77,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class EmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for EmbeddingsClient. 
Note that all parameters used to create this instance are saved as instance @@ -138,7 +137,7 @@ def _configure(self, **kwargs: Any) -> None: self.authentication_policy = self._infer_policy(**kwargs) -class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes,name-too-long +class ImageEmbeddingsClientConfiguration: # pylint: disable=too-many-instance-attributes """Configuration for ImageEmbeddingsClient. Note that all parameters used to create this instance are saved as instance diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py index d3ebd561f739..ab87088736aa 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/__init__.py @@ -5,13 +5,19 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._operations import ChatCompletionsClientOperationsMixin -from ._operations import EmbeddingsClientOperationsMixin -from ._operations import ImageEmbeddingsClientOperationsMixin +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import + +from ._operations import ChatCompletionsClientOperationsMixin # type: ignore +from ._operations import EmbeddingsClientOperationsMixin # type: ignore +from ._operations import ImageEmbeddingsClientOperationsMixin # type: ignore from ._patch import __all__ as _patch_all -from ._patch import * # pylint: disable=unused-wildcard-import +from ._patch import * from ._patch import patch_sdk as _patch_sdk __all__ = [ @@ -19,5 +25,5 @@ "EmbeddingsClientOperationsMixin", "ImageEmbeddingsClientOperationsMixin", ] -__all__.extend([p for p in _patch_all if p not in __all__]) +__all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py index 0be948bd275d..53015678105e 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_operations/_operations.py @@ -1,4 +1,3 @@ -# pylint: disable=too-many-lines,too-many-statements # coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
@@ -9,7 +8,7 @@ from io import IOBase import json import sys -from typing import Any, Callable, Dict, IO, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload from azure.core.exceptions import ( ClientAuthenticationError, @@ -41,7 +40,7 @@ if sys.version_info >= (3, 9): from collections.abc import MutableMapping else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports + from typing import MutableMapping # type: ignore JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object _Unset: Any = object() T = TypeVar("T") @@ -115,7 +114,6 @@ async def _complete( model: Optional[str] = None, **kwargs: Any ) -> _models.ChatCompletions: - # pylint: disable=too-many-locals """Gets chat completions for the provided chat messages. Completions support a wide variety of tasks and generate text that continues from or "completes" @@ -204,7 +202,7 @@ async def _complete( :rtype: ~azure.ai.inference.models.ChatCompletions :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -294,7 +292,7 @@ async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -425,7 +423,7 @@ async def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -507,7 +505,7 @@ async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -572,7 +570,7 @@ async def _embed( async def _embed( self, *, - input: List[_models.EmbeddingInput], + input: List[_models.ImageEmbeddingInput], extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, content_type: str = "application/json", dimensions: Optional[int] = None, @@ -596,7 +594,7 @@ async def _embed( self, body: Union[JSON, IO[bytes]] = _Unset, *, - input: List[_models.EmbeddingInput] = _Unset, + input: List[_models.ImageEmbeddingInput] = _Unset, extra_params: Optional[Union[str, _models._enums.ExtraParameters]] = None, dimensions: Optional[int] = None, encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, @@ -612,7 +610,7 @@ async def _embed( :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an array. The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] + :paramtype input: list[~azure.ai.inference.models.ImageEmbeddingInput] :keyword extra_params: Controls what happens if extra parameters, undefined by the REST API, are passed in the JSON request payload. This sets the HTTP request header ``extra-parameters``. 
Known values are: "error", "drop", and @@ -641,7 +639,7 @@ async def _embed( :rtype: ~azure.ai.inference.models.EmbeddingsResult :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, @@ -723,7 +721,7 @@ async def _get_model_info(self, **kwargs: Any) -> _models.ModelInfo: :rtype: ~azure.ai.inference.models.ModelInfo :raises ~azure.core.exceptions.HttpResponseError: """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { + error_map: MutableMapping = { 401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError, diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py index 2bdfd67a40cb..f7dd32510333 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_patch.py @@ -2,1266 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -# pylint: disable=too-many-lines) """Customize generated code here. Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -import json -import logging -import sys +from typing import List -from io import IOBase -from typing import Any, Dict, Union, IO, List, Literal, Optional, overload, Type, TYPE_CHECKING, AsyncIterable - -from azure.core.pipeline import PipelineResponse -from azure.core.credentials import AzureKeyCredential -from azure.core.tracing.decorator_async import distributed_trace_async -from azure.core.utils import case_insensitive_dict -from azure.core.exceptions import ( - ClientAuthenticationError, - HttpResponseError, - map_error, - ResourceExistsError, - ResourceNotFoundError, - ResourceNotModifiedError, -) -from .. import models as _models -from .._model_base import SdkJSONEncoder, _deserialize -from ._client import ChatCompletionsClient as ChatCompletionsClientGenerated -from ._client import EmbeddingsClient as EmbeddingsClientGenerated -from ._client import ImageEmbeddingsClient as ImageEmbeddingsClientGenerated -from .._operations._operations import ( - build_chat_completions_complete_request, - build_embeddings_embed_request, - build_image_embeddings_embed_request, -) - -if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports - from azure.core.credentials_async import AsyncTokenCredential - -if sys.version_info >= (3, 9): - from collections.abc import MutableMapping -else: - from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports - -JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object -_Unset: Any = object() -_LOGGER = logging.getLogger(__name__) - - -async def load_client( - endpoint: str, credential: Union[AzureKeyCredential, "AsyncTokenCredential"], **kwargs: Any -) -> Union["ChatCompletionsClient", "EmbeddingsClient", "ImageEmbeddingsClient"]: - # pylint: disable=line-too-long - """ - Load a client from a given endpoint URL. The method makes a REST API call to the `/info` route - on the given endpoint, to determine the model type and therefore which client to instantiate. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. 
- :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - :return: The appropriate asynchronous client associated with the given endpoint - :rtype: ~azure.ai.inference.aio.ChatCompletionsClient or ~azure.ai.inference.aio.EmbeddingsClient - or ~azure.ai.inference.aio.ImageEmbeddingsClient - :raises ~azure.core.exceptions.HttpResponseError: - """ - - async with ChatCompletionsClient( - endpoint, credential, **kwargs - ) as client: # Pick any of the clients, it does not matter. - model_info = await client.get_model_info() # type: ignore - - _LOGGER.info("model_info=%s", model_info) - if not model_info.model_type: - raise ValueError( - "The AI model information is missing a value for `model type`. Cannot create an appropriate client." - ) - - # TODO: Remove "completions" and "embedding" once Mistral Large and Cohere fixes their model type - if model_info.model_type in (_models.ModelType.CHAT, "completion", "chat-completion", "chat-completions"): - chat_completion_client = ChatCompletionsClient(endpoint, credential, **kwargs) - chat_completion_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return chat_completion_client - - if model_info.model_type in (_models.ModelType.EMBEDDINGS, "embedding"): - embedding_client = EmbeddingsClient(endpoint, credential, **kwargs) - embedding_client._model_info = model_info # pylint: disable=protected-access,attribute-defined-outside-init - return embedding_client - - if model_info.model_type == _models.ModelType.IMAGE_EMBEDDINGS: - image_embedding_client = ImageEmbeddingsClient(endpoint, credential, **kwargs) - image_embedding_client._model_info = ( # pylint: disable=protected-access,attribute-defined-outside-init - model_info - ) - return image_embedding_client - - raise ValueError(f"No client available to support AI model type `{model_info.model_type}`") - - -class ChatCompletionsClient(ChatCompletionsClientGenerated): # pylint: disable=too-many-instance-attributes - """ChatCompletionsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. 
- :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. 
- :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "AsyncTokenCredential"], - *, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default chat completions settings, to be applied in all future service calls - # unless overridden by arguments in the `complete` method. - self._frequency_penalty = frequency_penalty - self._presence_penalty = presence_penalty - self._temperature = temperature - self._top_p = top_p - self._max_tokens = max_tokens - self._response_format = response_format - self._stop = stop - self._tools = tools - self._tool_choice = tool_choice - self._seed = seed - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - async def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Literal[False] = False, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.ChatCompletions: ... 
- - @overload - async def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Literal[True], - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> AsyncIterable[_models.StreamingChatCompletionsUpdate]: ... - - @overload - async def complete( - self, - *, - messages: List[_models.ChatRequestMessage], - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. The method makes a REST API call to the `/chat/completions` route - on the given endpoint. - When using this method with `stream=True`, the response is streamed - back to the client. Iterate over the resulting StreamingChatCompletions - object to get content updates as they arrive. By default, the response is a ChatCompletions object - (non-streaming). - - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. 
- Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. - :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. 
- :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def complete( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def complete( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - - # pylint:disable=client-method-missing-tracing-decorator-async - async def complete( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - messages: List[_models.ChatRequestMessage] = _Unset, - stream: Optional[bool] = None, - frequency_penalty: Optional[float] = None, - presence_penalty: Optional[float] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - max_tokens: Optional[int] = None, - response_format: Optional[_models.ChatCompletionsResponseFormat] = None, - stop: Optional[List[str]] = None, - tools: Optional[List[_models.ChatCompletionsToolDefinition]] = None, - tool_choice: Optional[ - Union[str, _models.ChatCompletionsToolChoicePreset, _models.ChatCompletionsNamedToolChoice] - ] = None, - seed: Optional[int] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> Union[AsyncIterable[_models.StreamingChatCompletionsUpdate], _models.ChatCompletions]: - # pylint: disable=line-too-long - # pylint: disable=too-many-locals - """Gets chat completions for the provided chat messages. - Completions support a wide variety of tasks and generate text that continues from or - "completes" provided prompt data. When using this method with `stream=True`, the response is streamed - back to the client. 
Iterate over the resulting :class:`~azure.ai.inference.models.StreamingChatCompletions` - object to get content updates as they arrive. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword messages: The collection of context messages associated with this chat completions - request. - Typical usage begins with a chat message for the System role that provides instructions for - the behavior of the assistant, followed by alternating messages between the User and - Assistant roles. Required. - :paramtype messages: list[~azure.ai.inference.models.ChatRequestMessage] - :keyword stream: A value indicating whether chat completions should be streamed for this request. - Default value is False. If streaming is enabled, the response will be a StreamingChatCompletions. - Otherwise the response will be a ChatCompletions. - :paramtype stream: bool - :keyword frequency_penalty: A value that influences the probability of generated tokens - appearing based on their cumulative frequency in generated text. - Positive values will make tokens less likely to appear as their frequency increases and - decrease the likelihood of the model repeating the same statements verbatim. - Supported range is [-2, 2]. - Default value is None. - :paramtype frequency_penalty: float - :keyword presence_penalty: A value that influences the probability of generated tokens - appearing based on their existing - presence in generated text. - Positive values will make tokens less likely to appear when they already exist and increase - the model's likelihood to output new topics. - Supported range is [-2, 2]. - Default value is None. - :paramtype presence_penalty: float - :keyword temperature: The sampling temperature to use that controls the apparent creativity of - generated completions. - Higher values will make output more random while lower values will make results more focused - and deterministic. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype temperature: float - :keyword top_p: An alternative to sampling with temperature called nucleus sampling. This value - causes the - model to consider the results of tokens with the provided probability mass. As an example, a - value of 0.15 will cause only the tokens comprising the top 15% of probability mass to be - considered. - It is not recommended to modify temperature and top_p for the same completions request as the - interaction of these two settings is difficult to predict. - Supported range is [0, 1]. - Default value is None. - :paramtype top_p: float - :keyword max_tokens: The maximum number of tokens to generate. Default value is None. - :paramtype max_tokens: int - :keyword response_format: The format that the model must output. Use this to enable JSON mode - instead of the default text mode. - Note that to enable JSON mode, some AI models may also require you to instruct the model to - produce JSON via a system or user message. Default value is None. - :paramtype response_format: ~azure.ai.inference.models.ChatCompletionsResponseFormat - :keyword stop: A collection of textual sequences that will end completions generation. Default - value is None. 
- :paramtype stop: list[str] - :keyword tools: The available tool definitions that the chat completions request can use, - including caller-defined functions. Default value is None. - :paramtype tools: list[~azure.ai.inference.models.ChatCompletionsToolDefinition] - :keyword tool_choice: If specified, the model will configure which of the provided tools it can - use for the chat completions response. Is either a Union[str, - "_models.ChatCompletionsToolChoicePreset"] type or a ChatCompletionsNamedToolChoice type. - Default value is None. - :paramtype tool_choice: str or ~azure.ai.inference.models.ChatCompletionsToolChoicePreset or - ~azure.ai.inference.models.ChatCompletionsNamedToolChoice - :keyword seed: If specified, the system will make a best effort to sample deterministically - such that repeated requests with the - same seed and parameters should return the same result. Determinism is not guaranteed. - Default value is None. - :paramtype seed: int - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: ChatCompletions for non-streaming, or AsyncIterable[StreamingChatCompletionsUpdate] for streaming. - :rtype: ~azure.ai.inference.models.ChatCompletions or ~azure.ai.inference.models.AsyncStreamingChatCompletions - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if messages is _Unset: - raise TypeError("missing required argument: messages") - body = { - "messages": messages, - "stream": stream, - "frequency_penalty": frequency_penalty if frequency_penalty is not None else self._frequency_penalty, - "max_tokens": max_tokens if max_tokens is not None else self._max_tokens, - "model": model if model is not None else self._model, - "presence_penalty": presence_penalty if presence_penalty is not None else self._presence_penalty, - "response_format": response_format if response_format is not None else self._response_format, - "seed": seed if seed is not None else self._seed, - "stop": stop if stop is not None else self._stop, - "temperature": temperature if temperature is not None else self._temperature, - "tool_choice": tool_choice if tool_choice is not None else self._tool_choice, - "tools": tools if tools is not None else self._tools, - "top_p": top_p if top_p is not None else self._top_p, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = 
_models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - elif isinstance(body, dict) and "stream" in body and isinstance(body["stream"], bool): - stream = body["stream"] - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_chat_completions_complete_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = stream or False - pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - await response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - return _models.AsyncStreamingChatCompletions(response) - - return _deserialize(_models._patch.ChatCompletions, response.json()) # pylint: disable=protected-access - - @distributed_trace_async - async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class EmbeddingsClient(EmbeddingsClientGenerated): - """EmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. 
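For readers tracing the removed async `complete` implementation above, here is a minimal usage sketch of the streaming path it describes (`stream=True` returning an `AsyncStreamingChatCompletions` of `StreamingChatCompletionsUpdate` objects). The endpoint URL and key are placeholders, and the `azure.ai.inference.aio` import path is assumed from this package's layout; this is an illustrative sketch, not the package's documented sample.

```python
import asyncio

from azure.ai.inference.aio import ChatCompletionsClient  # assumed async import path
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential


async def stream_chat() -> None:
    client = ChatCompletionsClient(
        endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
        credential=AzureKeyCredential("<your-api-key>"),            # placeholder
    )
    async with client:
        # With stream=True, `complete` returns an async iterable of
        # StreamingChatCompletionsUpdate objects instead of a ChatCompletions.
        response = await client.complete(
            stream=True,
            messages=[
                SystemMessage(content="You are a helpful assistant."),
                UserMessage(content="Give me three facts about Jupiter."),
            ],
        )
        async for update in response:
            # Some updates carry no choices or empty deltas; guard before printing.
            if update.choices and update.choices[0].delta.content:
                print(update.choices[0].delta.content, end="")


asyncio.run(stream_chat())
```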
- :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. - :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "AsyncTokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - async def embed( - self, - *, - input: List[str], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". 
Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace_async - async def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[str] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given text prompts. - The method makes a REST API call to the `/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array - of strings or array of token arrays. Required. - :paramtype input: list[str] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. 
- :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - 
if response.status_code not in [200]: - if _stream: - await response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace_async - async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -class ImageEmbeddingsClient(ImageEmbeddingsClientGenerated): - """ImageEmbeddingsClient. - - :param endpoint: Service host. Required. - :type endpoint: str - :param credential: Credential used to authenticate requests to the service. Is either a - AzureKeyCredential type or a AsyncTokenCredential type. Required. - :type credential: ~azure.core.credentials.AzureKeyCredential or - ~azure.core.credentials_async.AsyncTokenCredential - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :keyword api_version: The API version to use for this operation. Default value is - "2024-05-01-preview". Note that overriding this default value may result in unsupported - behavior. 
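A corresponding sketch for the async `EmbeddingsClient.embed` call shown above, under the same assumptions (placeholder endpoint and key, `azure.ai.inference.aio` import path). Note how per-call keywords fall back to the defaults stored on the client, exactly as the body-building code above does.

```python
import asyncio

from azure.ai.inference.aio import EmbeddingsClient  # assumed async import path
from azure.core.credentials import AzureKeyCredential


async def embed_texts() -> None:
    client = EmbeddingsClient(
        endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
        credential=AzureKeyCredential("<your-api-key>"),            # placeholder
    )
    async with client:
        # `input` is required; dimensions, encoding_format, input_type, model and
        # model_extras are taken from the client defaults when omitted here.
        result = await client.embed(input=["first phrase", "second phrase"])
        for item in result.data:
            # item.embedding is a list[float] or a base64 string, depending on
            # the requested encoding_format.
            print(f"index={item.index}, values={len(item.embedding)}")


asyncio.run(embed_texts())
```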
- :paramtype api_version: str - """ - - def __init__( - self, - endpoint: str, - credential: Union[AzureKeyCredential, "AsyncTokenCredential"], - *, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> None: - - self._model_info: Optional[_models.ModelInfo] = None - - # Store default embeddings settings, to be applied in all future service calls - # unless overridden by arguments in the `embed` method. - self._dimensions = dimensions - self._encoding_format = encoding_format - self._input_type = input_type - self._model = model - self._model_extras = model_extras - - # For Key auth, we need to send these two auth HTTP request headers simultaneously: - # 1. "Authorization: Bearer " - # 2. "api-key: " - # This is because Serverless API, Managed Compute and GitHub endpoints support the first header, - # and Azure OpenAI and the new Unified Inference endpoints support the second header. - # The first header will be taken care of by auto-generated code. - # The second one is added here. - if isinstance(credential, AzureKeyCredential): - headers = kwargs.pop("headers", {}) - if "api-key" not in headers: - headers["api-key"] = credential.key - kwargs["headers"] = headers - - super().__init__(endpoint, credential, **kwargs) - - @overload - async def embed( - self, - *, - input: List[_models.EmbeddingInput], - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. 
The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: JSON, - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: An object of type MutableMapping[str, Any], such as a dictionary, that - specifies the full request payload. Required. - :type body: JSON - :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @overload - async def embed( - self, - body: IO[bytes], - *, - content_type: str = "application/json", - **kwargs: Any, - ) -> _models.EmbeddingsResult: - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Specifies the full request payload. Required. - :type body: IO[bytes] - :keyword content_type: Body Parameter content-type. Content type parameter for binary body. - Default value is "application/json". - :paramtype content_type: str - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - - @distributed_trace_async - async def embed( - self, - body: Union[JSON, IO[bytes]] = _Unset, - *, - input: List[_models.EmbeddingInput] = _Unset, - dimensions: Optional[int] = None, - encoding_format: Optional[Union[str, _models.EmbeddingEncodingFormat]] = None, - input_type: Optional[Union[str, _models.EmbeddingInputType]] = None, - model: Optional[str] = None, - model_extras: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> _models.EmbeddingsResult: - # pylint: disable=line-too-long - """Return the embedding vectors for given images. - The method makes a REST API call to the `/images/embeddings` route on the given endpoint. - - :param body: Is either a MutableMapping[str, Any] type (like a dictionary) or a IO[bytes] type - that specifies the full request payload. Required. - :type body: JSON or IO[bytes] - :keyword input: Input image to embed. To embed multiple inputs in a single request, pass an - array. - The input must not exceed the max input tokens for the model. Required. - :paramtype input: list[~azure.ai.inference.models.EmbeddingInput] - :keyword dimensions: Optional. The number of dimensions the resulting output embeddings should - have. Default value is None. - :paramtype dimensions: int - :keyword encoding_format: Optional. The desired format for the returned embeddings. - Known values are: - "base64", "binary", "float", "int8", "ubinary", and "uint8". Default value is None. - :paramtype encoding_format: str or ~azure.ai.inference.models.EmbeddingEncodingFormat - :keyword input_type: Optional. The type of the input. Known values are: - "text", "query", and "document". Default value is None. - :paramtype input_type: str or ~azure.ai.inference.models.EmbeddingInputType - :keyword model: ID of the specific AI model to use, if more than one model is available on the - endpoint. 
Default value is None. - :paramtype model: str - :keyword model_extras: Additional, model-specific parameters that are not in the - standard request payload. They will be added as-is to the root of the JSON in the request body. - How the service handles these extra parameters depends on the value of the - ``extra-parameters`` request header. Default value is None. - :paramtype model_extras: dict[str, Any] - :return: EmbeddingsResult. The EmbeddingsResult is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.EmbeddingsResult - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping[int, Type[HttpResponseError]] = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = kwargs.pop("params", {}) or {} - _extra_parameters: Union[_models._enums.ExtraParameters, None] = None - - content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - - if body is _Unset: - if input is _Unset: - raise TypeError("missing required argument: input") - body = { - "input": input, - "dimensions": dimensions if dimensions is not None else self._dimensions, - "encoding_format": encoding_format if encoding_format is not None else self._encoding_format, - "input_type": input_type if input_type is not None else self._input_type, - "model": model if model is not None else self._model, - } - if model_extras is not None and bool(model_extras): - body.update(model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - elif self._model_extras is not None and bool(self._model_extras): - body.update(self._model_extras) - _extra_parameters = _models._enums.ExtraParameters.PASS_THROUGH # pylint: disable=protected-access - body = {k: v for k, v in body.items() if v is not None} - content_type = content_type or "application/json" - _content = None - if isinstance(body, (IOBase, bytes)): - _content = body - else: - _content = json.dumps(body, cls=SdkJSONEncoder, exclude_readonly=True) # type: ignore - - _request = build_image_embeddings_embed_request( - extra_params=_extra_parameters, - content_type=content_type, - api_version=self._config.api_version, - content=_content, - headers=_headers, - params=_params, - ) - path_format_arguments = { - "endpoint": self._serialize.url("self._config.endpoint", self._config.endpoint, "str", skip_quote=True), - } - _request.url = self._client.format_url(_request.url, **path_format_arguments) - - _stream = kwargs.pop("stream", False) - pipeline_response: PipelineResponse = await self._client._pipeline.run( # type: ignore # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - if _stream: - await response.read() # Load the body in memory and close the socket - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - if _stream: - deserialized = response.iter_bytes() - else: - deserialized = _deserialize( - _models._patch.EmbeddingsResult, response.json() # pylint: disable=protected-access - ) - - return deserialized # type: ignore - - @distributed_trace_async - async def get_model_info(self, **kwargs: Any) -> _models.ModelInfo: - # pylint: disable=line-too-long - """Returns 
information about the AI model. - The method makes a REST API call to the ``/info`` route on the given endpoint. - This method will only work when using Serverless API or Managed Compute endpoint. - It will not work for GitHub Models endpoint or Azure OpenAI endpoint. - - :return: ModelInfo. The ModelInfo is compatible with MutableMapping - :rtype: ~azure.ai.inference.models.ModelInfo - :raises ~azure.core.exceptions.HttpResponseError: - """ - if not self._model_info: - self._model_info = await self._get_model_info(**kwargs) # pylint: disable=attribute-defined-outside-init - return self._model_info - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return super().__str__() + f"\n{self._model_info}" if self._model_info else super().__str__() - - -__all__: List[str] = [ - "load_client", - "ChatCompletionsClient", - "EmbeddingsClient", - "ImageEmbeddingsClient", -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py index dd91e1ea130f..b430582ca1fc 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/aio/_vendor.py @@ -15,7 +15,6 @@ ) if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from azure.core import AsyncPipelineClient from .._serialization import Deserializer, Serializer diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py index 1832edc83399..bebd5c247a23 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/__init__.py @@ -5,54 +5,62 @@ # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
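The `get_model_info` method removed above is shared by all three async clients and caches its result after the first call. A hedged sketch of that pattern follows; per the docstring it only succeeds against Serverless API or Managed Compute endpoints, and the endpoint and key below are placeholders.

```python
import asyncio

from azure.ai.inference.aio import ChatCompletionsClient  # assumed async import path
from azure.core.credentials import AzureKeyCredential


async def show_model_info() -> None:
    client = ChatCompletionsClient(
        endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
        credential=AzureKeyCredential("<your-api-key>"),            # placeholder
    )
    async with client:
        # First call issues a GET to the /info route; the ModelInfo is cached on
        # the client, so later calls (and __str__) reuse it without another request.
        info = await client.get_model_info()
        print(info.model_name, info.model_type, info.model_provider_name)


asyncio.run(show_model_info())
```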
# -------------------------------------------------------------------------- +# pylint: disable=wrong-import-position -from ._models import AssistantMessage -from ._models import ChatChoice -from ._patch import ChatCompletions -from ._models import ChatCompletionsNamedToolChoice -from ._models import ChatCompletionsNamedToolChoiceFunction -from ._models import ChatCompletionsResponseFormat -from ._models import ChatCompletionsResponseFormatJSON -from ._models import ChatCompletionsResponseFormatText -from ._models import ChatCompletionsToolCall -from ._models import ChatCompletionsToolDefinition -from ._models import ChatRequestMessage -from ._models import ChatResponseMessage -from ._models import CompletionsUsage -from ._models import ContentItem -from ._models import EmbeddingInput -from ._models import EmbeddingItem -from ._patch import EmbeddingsResult -from ._models import EmbeddingsUsage -from ._models import FunctionCall -from ._models import FunctionDefinition -from ._models import ImageContentItem -from ._patch import ImageUrl -from ._models import ModelInfo -from ._models import StreamingChatChoiceUpdate -from ._models import StreamingChatCompletionsUpdate -from ._models import StreamingChatResponseMessageUpdate -from ._models import StreamingChatResponseToolCallUpdate -from ._models import SystemMessage -from ._models import TextContentItem -from ._models import ToolMessage -from ._models import UserMessage +from typing import TYPE_CHECKING -from ._enums import ChatCompletionsToolChoicePreset -from ._enums import ChatRole -from ._enums import CompletionsFinishReason -from ._enums import EmbeddingEncodingFormat -from ._enums import EmbeddingInputType -from ._enums import ImageDetailLevel -from ._enums import ModelType +if TYPE_CHECKING: + from ._patch import * # pylint: disable=unused-wildcard-import -from ._patch import StreamingChatCompletions -from ._patch import AsyncStreamingChatCompletions + +from ._models import ( # type: ignore + AssistantMessage, + ChatChoice, + ChatCompletions, + ChatCompletionsNamedToolChoice, + ChatCompletionsNamedToolChoiceFunction, + ChatCompletionsResponseFormat, + ChatCompletionsResponseFormatJSON, + ChatCompletionsResponseFormatText, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + ChatRequestMessage, + ChatResponseMessage, + CompletionsUsage, + ContentItem, + EmbeddingItem, + EmbeddingsResult, + EmbeddingsUsage, + FunctionCall, + FunctionDefinition, + ImageContentItem, + ImageEmbeddingInput, + ImageUrl, + ModelInfo, + StreamingChatChoiceUpdate, + StreamingChatCompletionsUpdate, + StreamingChatResponseMessageUpdate, + StreamingChatResponseToolCallUpdate, + SystemMessage, + TextContentItem, + ToolMessage, + UserMessage, +) + +from ._enums import ( # type: ignore + ChatCompletionsToolChoicePreset, + ChatRole, + CompletionsFinishReason, + EmbeddingEncodingFormat, + EmbeddingInputType, + ImageDetailLevel, + ModelType, +) +from ._patch import __all__ as _patch_all +from ._patch import * from ._patch import patch_sdk as _patch_sdk __all__ = [ - "StreamingChatCompletions", - "AsyncStreamingChatCompletions", "AssistantMessage", "ChatChoice", "ChatCompletions", @@ -67,13 +75,13 @@ "ChatResponseMessage", "CompletionsUsage", "ContentItem", - "EmbeddingInput", "EmbeddingItem", "EmbeddingsResult", "EmbeddingsUsage", "FunctionCall", "FunctionDefinition", "ImageContentItem", + "ImageEmbeddingInput", "ImageUrl", "ModelInfo", "StreamingChatChoiceUpdate", @@ -92,5 +100,5 @@ "ImageDetailLevel", "ModelType", ] - +__all__.extend([p for p in _patch_all 
if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py index 830a93f75472..61443cbfbb85 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_enums.py @@ -121,14 +121,14 @@ class ModelType(str, Enum, metaclass=CaseInsensitiveEnumMeta): """The type of AI model.""" EMBEDDINGS = "embeddings" - """Embeddings.""" + """A model capable of generating embeddings from a text""" IMAGE_GENERATION = "image_generation" - """Image generation""" + """A model capable of generating images from an image and text description""" TEXT_GENERATION = "text_generation" - """Text generation""" + """A text generation model""" IMAGE_EMBEDDINGS = "image_embeddings" - """Image embeddings""" + """A model capable of generating embeddings from an image""" AUDIO_GENERATION = "audio_generation" - """Audio generation""" - CHAT = "chat" - """Chat completions""" + """A text-to-audio generative model""" + CHAT_COMPLETION = "chat_completion" + """A model capable of taking chat-formatted messages and generate responses""" diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py index 4ac8f16f94d1..8439f33fc234 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_models.py @@ -1,11 +1,12 @@ -# coding=utf-8 # pylint: disable=too-many-lines +# coding=utf-8 # -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # Code generated by Microsoft (R) Python Code Generator. # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +# pylint: disable=useless-super-delegation import datetime from typing import Any, Dict, List, Literal, Mapping, Optional, TYPE_CHECKING, Union, overload @@ -15,7 +16,6 @@ from ._enums import ChatRole if TYPE_CHECKING: - # pylint: disable=unused-import,ungrouped-imports from .. import models as _models @@ -42,16 +42,16 @@ def __init__( self, *, role: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -88,16 +88,16 @@ def __init__( *, content: Optional[str] = None, tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. 
:type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.ASSISTANT, **kwargs) @@ -132,16 +132,16 @@ def __init__( index: int, finish_reason: Union[str, "_models.CompletionsFinishReason"], message: "_models.ChatResponseMessage", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -193,16 +193,16 @@ def __init__( model: str, usage: "_models.CompletionsUsage", choices: List["_models.ChatChoice"], - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -232,10 +232,10 @@ def __init__( self, *, function: "_models.ChatCompletionsNamedToolChoiceFunction", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] @@ -264,16 +264,16 @@ def __init__( self, *, name: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -303,16 +303,16 @@ def __init__( self, *, type: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -336,16 +336,16 @@ class ChatCompletionsResponseFormatJSON(ChatCompletionsResponseFormat, discrimin @overload def __init__( self, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, type="json_object", **kwargs) @@ -366,16 +366,16 @@ class ChatCompletionsResponseFormatText(ChatCompletionsResponseFormat, discrimin @overload def __init__( self, - ): ... + ) -> None: ... 
@overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, type="text", **kwargs) @@ -408,10 +408,10 @@ def __init__( *, id: str, # pylint: disable=redefined-builtin function: "_models.FunctionCall", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] @@ -447,10 +447,10 @@ def __init__( self, *, function: "_models.FunctionDefinition", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] @@ -493,16 +493,16 @@ def __init__( role: Union[str, "_models.ChatRole"], content: str, tool_calls: Optional[List["_models.ChatCompletionsToolCall"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -537,16 +537,16 @@ def __init__( completion_tokens: int, prompt_tokens: int, total_tokens: int, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -571,53 +571,16 @@ def __init__( self, *, type: str, - ): ... - - @overload - def __init__(self, mapping: Mapping[str, Any]): - """ - :param mapping: raw JSON to initialize the model. - :type mapping: Mapping[str, Any] - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation - super().__init__(*args, **kwargs) - - -class EmbeddingInput(_model_base.Model): - """Represents an image with optional text. - - All required parameters must be populated in order to send to server. - - :ivar image: The input image, in PNG format. Required. - :vartype image: str - :ivar text: Optional. The text input to feed into the model (like DINO, CLIP). - Returns a 422 error if the model doesn't support the value or parameter. - :vartype text: str - """ - - image: str = rest_field() - """The input image, in PNG format. Required.""" - text: Optional[str] = rest_field() - """Optional. The text input to feed into the model (like DINO, CLIP). - Returns a 422 error if the model doesn't support the value or parameter.""" + ) -> None: ... @overload - def __init__( - self, - *, - image: str, - text: Optional[str] = None, - ): ... - - @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. 
:type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -634,7 +597,7 @@ class EmbeddingItem(_model_base.Model): :vartype index: int """ - embedding: Union["str", List[float]] = rest_field() + embedding: Union[str, List[float]] = rest_field() """List of embedding values for the input prompt. These represent a measurement of the vector-based relatedness of the provided input. Or a base64 encoded string of the embedding vector. Required. Is either a str type or a [float] type.""" @@ -647,16 +610,16 @@ def __init__( *, embedding: Union[str, List[float]], index: int, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -694,16 +657,16 @@ def __init__( data: List["_models.EmbeddingItem"], usage: "_models.EmbeddingsUsage", model: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -731,16 +694,16 @@ def __init__( *, prompt_tokens: int, total_tokens: int, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -772,16 +735,16 @@ def __init__( *, name: str, arguments: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -817,16 +780,16 @@ def __init__( name: str, description: Optional[str] = None, parameters: Optional[Any] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -855,19 +818,58 @@ def __init__( self, *, image_url: "_models.ImageUrl", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. 
:type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, type="image_url", **kwargs) +class ImageEmbeddingInput(_model_base.Model): + """Represents an image with optional text. + + All required parameters must be populated in order to send to server. + + :ivar image: The input image encoded in base64 string as a data URL. Example: + ``data:image/{format};base64,{data}``. Required. + :vartype image: str + :ivar text: Optional. The text input to feed into the model (like DINO, CLIP). + Returns a 422 error if the model doesn't support the value or parameter. + :vartype text: str + """ + + image: str = rest_field() + """The input image encoded in base64 string as a data URL. Example: + ``data:image/{format};base64,{data}``. Required.""" + text: Optional[str] = rest_field() + """Optional. The text input to feed into the model (like DINO, CLIP). + Returns a 422 error if the model doesn't support the value or parameter.""" + + @overload + def __init__( + self, + *, + image: str, + text: Optional[str] = None, + ) -> None: ... + + @overload + def __init__(self, mapping: Mapping[str, Any]) -> None: + """ + :param mapping: raw JSON to initialize the model. + :type mapping: Mapping[str, Any] + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + class ImageUrl(_model_base.Model): """An internet location from which the model may retrieve an image. @@ -894,16 +896,16 @@ def __init__( *, url: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -915,7 +917,7 @@ class ModelInfo(_model_base.Model): :vartype model_name: str :ivar model_type: The type of the AI model. A Unique identifier for the profile. Required. Known values are: "embeddings", "image_generation", "text_generation", "image_embeddings", - "audio_generation", and "chat". + "audio_generation", and "chat_completion". :vartype model_type: str or ~azure.ai.inference.models.ModelType :ivar model_provider_name: The model provider name. For example: ``Microsoft Research``. Required. @@ -927,7 +929,7 @@ class ModelInfo(_model_base.Model): model_type: Union[str, "_models.ModelType"] = rest_field() """The type of the AI model. A Unique identifier for the profile. Required. Known values are: \"embeddings\", \"image_generation\", \"text_generation\", \"image_embeddings\", - \"audio_generation\", and \"chat\".""" + \"audio_generation\", and \"chat_completion\".""" model_provider_name: str = rest_field() """The model provider name. For example: ``Microsoft Research``. Required.""" @@ -938,16 +940,16 @@ def __init__( model_name: str, model_type: Union[str, "_models.ModelType"], model_provider_name: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. 
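Since the new `ImageEmbeddingInput` model introduced above expects a base64 data URL, here is a hedged sketch of building one and passing it to the async `ImageEmbeddingsClient.embed`; the file name, endpoint and key are placeholders.

```python
import asyncio
import base64

from azure.ai.inference.aio import ImageEmbeddingsClient  # assumed async import path
from azure.ai.inference.models import ImageEmbeddingInput
from azure.core.credentials import AzureKeyCredential


async def embed_image() -> None:
    # The `image` field expects a data URL such as "data:image/{format};base64,{data}".
    with open("sample.png", "rb") as f:  # placeholder file
        data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

    client = ImageEmbeddingsClient(
        endpoint="https://<your-endpoint>.inference.ai.azure.com",  # placeholder
        credential=AzureKeyCredential("<your-api-key>"),            # placeholder
    )
    async with client:
        result = await client.embed(
            input=[ImageEmbeddingInput(image=data_url, text="a photo")]
        )
        print(result.data[0].index, len(result.data[0].embedding))


asyncio.run(embed_image())
```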
:type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -982,16 +984,16 @@ def __init__( index: int, finish_reason: Union[str, "_models.CompletionsFinishReason"], delta: "_models.StreamingChatResponseMessageUpdate", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1046,16 +1048,16 @@ def __init__( model: str, usage: "_models.CompletionsUsage", choices: List["_models.StreamingChatChoiceUpdate"], - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1090,16 +1092,16 @@ def __init__( role: Optional[Union[str, "_models.ChatRole"]] = None, content: Optional[str] = None, tool_calls: Optional[List["_models.StreamingChatResponseToolCallUpdate"]] = None, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1124,16 +1126,16 @@ def __init__( *, id: str, # pylint: disable=redefined-builtin function: "_models.FunctionCall", - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -1162,16 +1164,16 @@ def __init__( self, *, content: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.SYSTEM, **kwargs) @@ -1198,16 +1200,16 @@ def __init__( self, *, text: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. 
:type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, type="text", **kwargs) @@ -1240,16 +1242,16 @@ def __init__( *, content: str, tool_call_id: str, - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.TOOL, **kwargs) @@ -1269,7 +1271,7 @@ class UserMessage(ChatRequestMessage, discriminator="user"): role: Literal[ChatRole.USER] = rest_discriminator(name="role") # type: ignore """The chat role associated with this message, which is always 'user' for user messages. Required. The role that provides input for chat completions.""" - content: Union["str", List["_models.ContentItem"]] = rest_field() + content: Union[str, List["_models.ContentItem"]] = rest_field() """The contents of the user message, with available input types varying by selected model. Required. Is either a str type or a [ContentItem] type.""" @@ -1278,14 +1280,14 @@ def __init__( self, *, content: Union[str, List["_models.ContentItem"]], - ): ... + ) -> None: ... @overload - def __init__(self, mapping: Mapping[str, Any]): + def __init__(self, mapping: Mapping[str, Any]) -> None: """ :param mapping: raw JSON to initialize the model. :type mapping: Mapping[str, Any] """ - def __init__(self, *args: Any, **kwargs: Any) -> None: # pylint: disable=useless-super-delegation + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, role=ChatRole.USER, **kwargs) diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py index 61c718eea63f..f7dd32510333 100644 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py +++ b/sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py @@ -6,273 +6,9 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ -import asyncio -import base64 -import json -import logging -import queue -import re -import sys +from typing import List -from typing import Any, List, AsyncIterator, Iterator, Optional, Union -from azure.core.rest import HttpResponse, AsyncHttpResponse -from ._models import ImageUrl as ImageUrlGenerated -from ._models import ChatCompletions as ChatCompletionsGenerated -from ._models import EmbeddingsResult as EmbeddingsResultGenerated -from .. import models as _models - -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self - -logger = logging.getLogger(__name__) - - -class ChatCompletions(ChatCompletionsGenerated): - """Representation of the response data from a chat completions request. - Completions support a wide variety of tasks and generate text that continues from or - "completes" - provided prompt data. - - - :ivar id: A unique identifier associated with this chat completions response. Required. - :vartype id: str - :ivar created: The first timestamp associated with generation activity for this completions - response, - represented as seconds since the beginning of the Unix epoch of 00:00 on 1 Jan 1970. Required. 
- :vartype created: ~datetime.datetime - :ivar model: The model used for the chat completion. Required. - :vartype model: str - :ivar usage: Usage information for tokens processed and generated as part of this completions - operation. Required. - :vartype usage: ~azure.ai.inference.models.CompletionsUsage - :ivar choices: The collection of completions choices associated with this completions response. - Generally, ``n`` choices are generated per provided prompt with a default value of 1. - Token limits and other settings may limit the number of choices generated. Required. - :vartype choices: list[~azure.ai.inference.models.ChatChoice] - """ - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return json.dumps(self.as_dict(), indent=2) - - -class EmbeddingsResult(EmbeddingsResultGenerated): - """Representation of the response data from an embeddings request. - Embeddings measure the relatedness of text strings and are commonly used for search, - clustering, - recommendations, and other similar scenarios. - - - :ivar data: Embedding values for the prompts submitted in the request. Required. - :vartype data: list[~azure.ai.inference.models.EmbeddingItem] - :ivar usage: Usage counts for tokens input using the embeddings API. Required. - :vartype usage: ~azure.ai.inference.models.EmbeddingsUsage - :ivar model: The model ID used to generate this result. Required. - :vartype model: str - """ - - def __str__(self) -> str: - # pylint: disable=client-method-name-no-double-underscore - return json.dumps(self.as_dict(), indent=2) - - -class ImageUrl(ImageUrlGenerated): - - @classmethod - def load( - cls, *, image_file: str, image_format: str, detail: Optional[Union[str, "_models.ImageDetailLevel"]] = None - ) -> Self: - """ - Create an ImageUrl object from a local image file. The method reads the image - file and encodes it as a base64 string, which together with the image format - is then used to format the JSON `url` value passed in the request payload. - - :ivar image_file: The name of the local image file to load. Required. - :vartype image_file: str - :ivar image_format: The MIME type format of the image. For example: "jpeg", "png". Required. - :vartype image_format: str - :ivar detail: The evaluation quality setting to use, which controls relative prioritization of - speed, token consumption, and accuracy. Known values are: "auto", "low", and "high". - :vartype detail: str or ~azure.ai.inference.models.ImageDetailLevel - :return: An ImageUrl object with the image data encoded as a base64 string. - :rtype: ~azure.ai.inference.models.ImageUrl - :raises FileNotFoundError: when the image file could not be opened. - """ - with open(image_file, "rb") as f: - image_data = base64.b64encode(f.read()).decode("utf-8") - url = f"data:image/{image_format};base64,{image_data}" - return cls(url=url, detail=detail) - - -class BaseStreamingChatCompletions: - """A base class for the sync and async streaming chat completions responses, holding any common code - to deserializes the Server Sent Events (SSE) response stream into chat completions updates, each one - represented by a StreamingChatCompletionsUpdate object. - """ - - # Enable detailed logs of SSE parsing. For development only, should be `False` by default. 
- _ENABLE_CLASS_LOGS = False - - # The prefix of each line in the SSE stream that contains a JSON string - # to deserialize into a StreamingChatCompletionsUpdate object - _SSE_DATA_EVENT_PREFIX = "data: " - - # The line indicating the end of the SSE stream - _SSE_DATA_EVENT_DONE = "data: [DONE]" - - def __init__(self): - self._queue: "queue.Queue[_models.StreamingChatCompletionsUpdate]" = queue.Queue() - self._incomplete_json = "" - self._done = False # Will be set to True when reading 'data: [DONE]' line - - def _deserialize_and_add_to_queue(self, element: bytes) -> bool: - - # Clear the queue of StreamingChatCompletionsUpdate before processing the next block - self._queue.queue.clear() - - # Convert `bytes` to string and split the string by newline, while keeping the new line char. - # the last may be a partial "line" that does not contain a newline char at the end. - line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8")) - for index, line in enumerate(line_list): - - if self._ENABLE_CLASS_LOGS: - logger.debug("[Original line] %s", repr(line)) - - if index == 0: - line = self._incomplete_json + line - self._incomplete_json = "" - - if index == len(line_list) - 1 and not line.endswith("\n"): - self._incomplete_json = line - return False - - if self._ENABLE_CLASS_LOGS: - logger.debug("[Modified line] %s", repr(line)) - - if line == "\n": # Empty line, indicating flush output to client - continue - - if not line.startswith(self._SSE_DATA_EVENT_PREFIX): - raise ValueError(f"SSE event not supported (line `{line}`)") - - if line.startswith(self._SSE_DATA_EVENT_DONE): - if self._ENABLE_CLASS_LOGS: - logger.debug("[Done]") - return True - - # If you reached here, the line should contain `data: {...}\n` - # where the curly braces contain a valid JSON object. - # Deserialize it into a StreamingChatCompletionsUpdate object - # and add it to the queue. - # pylint: disable=W0212 # Access to a protected member _deserialize of a client class - update = _models.StreamingChatCompletionsUpdate._deserialize( - json.loads(line[len(self._SSE_DATA_EVENT_PREFIX) : -1]), [] - ) - - # We skip any update that has a None or empty choices list - # (this is what OpenAI Python SDK does) - if update.choices: - - # We update all empty content strings to None - # (this is what OpenAI Python SDK does) - # for choice in update.choices: - # if not choice.delta.content: - # choice.delta.content = None - - self._queue.put(update) - - if self._ENABLE_CLASS_LOGS: - logger.debug("[Added to queue]") - - return False - - -class StreamingChatCompletions(BaseStreamingChatCompletions): - """Represents an interator over StreamingChatCompletionsUpdate objects. It can be used for either synchronous or - asynchronous iterations. The class deserializes the Server Sent Events (SSE) response stream - into chat completions updates, each one represented by a StreamingChatCompletionsUpdate object. 
- """ - - def __init__(self, response: HttpResponse): - super().__init__() - self._response = response - self._bytes_iterator: Iterator[bytes] = response.iter_bytes() - - def __iter__(self) -> Any: - return self - - def __next__(self) -> "_models.StreamingChatCompletionsUpdate": - while self._queue.empty() and not self._done: - self._done = self._read_next_block() - if self._queue.empty(): - raise StopIteration - return self._queue.get() - - def _read_next_block(self) -> bool: - if self._ENABLE_CLASS_LOGS: - logger.debug("[Reading next block]") - try: - element = self._bytes_iterator.__next__() - except StopIteration: - self.close() - return True - return self._deserialize_and_add_to_queue(element) - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # type: ignore - self.close() - - def close(self) -> None: - self._response.close() - - -class AsyncStreamingChatCompletions(BaseStreamingChatCompletions): - """Represents an async interator over StreamingChatCompletionsUpdate objects. - It can be used for either synchronous or asynchronous iterations. The class - deserializes the Server Sent Events (SSE) response stream into chat - completions updates, each one represented by a StreamingChatCompletionsUpdate object. - """ - - def __init__(self, response: AsyncHttpResponse): - super().__init__() - self._response = response - self._bytes_iterator: AsyncIterator[bytes] = response.iter_bytes() - - def __aiter__(self) -> Any: - return self - - async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate": - while self._queue.empty() and not self._done: - self._done = await self._read_next_block_async() - if self._queue.empty(): - raise StopAsyncIteration - return self._queue.get() - - async def _read_next_block_async(self) -> bool: - if self._ENABLE_CLASS_LOGS: - logger.debug("[Reading next block]") - try: - element = await self._bytes_iterator.__anext__() - except StopAsyncIteration: - await self.aclose() - return True - return self._deserialize_and_add_to_queue(element) - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # type: ignore - asyncio.run(self.aclose()) - - async def aclose(self) -> None: - await self._response.close() - - -__all__: List[str] = [ - "ImageUrl", - "ChatCompletions", - "EmbeddingsResult", - "StreamingChatCompletions", - "AsyncStreamingChatCompletions", -] # Add all objects you want publicly available to users at this package level +__all__: List[str] = [] # Add all objects you want publicly available to users at this package level def patch_sdk(): diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/__init__.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/__init__.py deleted file mode 100644 index 2e11b31cb6a4..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -# pylint: disable=unused-import -from ._patch import patch_sdk as _patch_sdk, PromptTemplate - -_patch_sdk() diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_core.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_core.py deleted file mode 100644 index ec6702995149..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_core.py +++ /dev/null @@ -1,312 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-# ------------------------------------ -# mypy: disable-error-code="assignment,attr-defined,index,arg-type" -# pylint: disable=line-too-long,R,consider-iterating-dictionary,raise-missing-from,dangerous-default-value -from __future__ import annotations -import os -from dataclasses import dataclass, field, asdict -from pathlib import Path -from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Union -from ._tracer import Tracer, to_dict -from ._utils import load_json - - -@dataclass -class ToolCall: - id: str - name: str - arguments: str - - -@dataclass -class PropertySettings: - """PropertySettings class to define the properties of the model - - Attributes - ---------- - type : str - The type of the property - default : Any - The default value of the property - description : str - The description of the property - """ - - type: Literal["string", "number", "array", "object", "boolean"] - default: Union[str, int, float, List, Dict, bool, None] = field(default=None) - description: str = field(default="") - - -@dataclass -class ModelSettings: - """ModelSettings class to define the model of the prompty - - Attributes - ---------- - api : str - The api of the model - configuration : Dict - The configuration of the model - parameters : Dict - The parameters of the model - response : Dict - The response of the model - """ - - api: str = field(default="") - configuration: Dict = field(default_factory=dict) - parameters: Dict = field(default_factory=dict) - response: Dict = field(default_factory=dict) - - -@dataclass -class TemplateSettings: - """TemplateSettings class to define the template of the prompty - - Attributes - ---------- - type : str - The type of the template - parser : str - The parser of the template - """ - - type: str = field(default="mustache") - parser: str = field(default="") - - -@dataclass -class Prompty: - """Prompty class to define the prompty - - Attributes - ---------- - name : str - The name of the prompty - description : str - The description of the prompty - authors : List[str] - The authors of the prompty - tags : List[str] - The tags of the prompty - version : str - The version of the prompty - base : str - The base of the prompty - basePrompty : Prompty - The base prompty - model : ModelSettings - The model of the prompty - sample : Dict - The sample of the prompty - inputs : Dict[str, PropertySettings] - The inputs of the prompty - outputs : Dict[str, PropertySettings] - The outputs of the prompty - template : TemplateSettings - The template of the prompty - file : FilePath - The file of the prompty - content : Union[str, List[str], Dict] - The content of the prompty - """ - - # metadata - name: str = field(default="") - description: str = field(default="") - authors: List[str] = field(default_factory=list) - tags: List[str] = field(default_factory=list) - version: str = field(default="") - base: str = field(default="") - basePrompty: Union[Prompty, None] = field(default=None) - # model - model: ModelSettings = field(default_factory=ModelSettings) - - # sample - sample: Dict = field(default_factory=dict) - - # input / output - inputs: Dict[str, PropertySettings] = field(default_factory=dict) - outputs: Dict[str, PropertySettings] = field(default_factory=dict) - - # template - template: TemplateSettings = field(default_factory=TemplateSettings) - - file: Union[Path, str] = field(default="") - content: Union[str, List[str], Dict] = field(default="") - - def to_safe_dict(self) -> Dict[str, Any]: - d = {} - if self.model: - d["model"] = asdict(self.model) 
- _mask_secrets(d, ["model", "configuration"]) - if self.template: - d["template"] = asdict(self.template) - if self.inputs: - d["inputs"] = {k: asdict(v) for k, v in self.inputs.items()} - if self.outputs: - d["outputs"] = {k: asdict(v) for k, v in self.outputs.items()} - if self.file: - d["file"] = str(self.file.as_posix()) if isinstance(self.file, Path) else self.file - return d - - @staticmethod - def hoist_base_prompty(top: Prompty, base: Prompty) -> Prompty: - top.name = base.name if top.name == "" else top.name - top.description = base.description if top.description == "" else top.description - top.authors = list(set(base.authors + top.authors)) - top.tags = list(set(base.tags + top.tags)) - top.version = base.version if top.version == "" else top.version - - top.model.api = base.model.api if top.model.api == "" else top.model.api - top.model.configuration = param_hoisting(top.model.configuration, base.model.configuration) - top.model.parameters = param_hoisting(top.model.parameters, base.model.parameters) - top.model.response = param_hoisting(top.model.response, base.model.response) - - top.sample = param_hoisting(top.sample, base.sample) - - top.basePrompty = base - - return top - - @staticmethod - def _process_file(file: str, parent: Path) -> Any: - file_path = Path(parent / Path(file)).resolve().absolute() - if file_path.exists(): - items = load_json(file_path) - if isinstance(items, list): - return [Prompty.normalize(value, parent) for value in items] - elif isinstance(items, Dict): - return {key: Prompty.normalize(value, parent) for key, value in items.items()} - else: - return items - else: - raise FileNotFoundError(f"File {file} not found") - - @staticmethod - def _process_env(variable: str, env_error=True, default: Union[str, None] = None) -> Any: - if variable in os.environ.keys(): - return os.environ[variable] - else: - if default: - return default - if env_error: - raise ValueError(f"Variable {variable} not found in environment") - - return "" - - @staticmethod - def normalize(attribute: Any, parent: Path, env_error=True) -> Any: - if isinstance(attribute, str): - attribute = attribute.strip() - if attribute.startswith("${") and attribute.endswith("}"): - # check if env or file - variable = attribute[2:-1].split(":") - if variable[0] == "env" and len(variable) > 1: - return Prompty._process_env( - variable[1], - env_error, - variable[2] if len(variable) > 2 else None, - ) - elif variable[0] == "file" and len(variable) > 1: - return Prompty._process_file(variable[1], parent) - else: - raise ValueError(f"Invalid attribute format ({attribute})") - else: - return attribute - elif isinstance(attribute, list): - return [Prompty.normalize(value, parent) for value in attribute] - elif isinstance(attribute, Dict): - return {key: Prompty.normalize(value, parent) for key, value in attribute.items()} - else: - return attribute - - -def param_hoisting(top: Dict[str, Any], bottom: Dict[str, Any], top_key: Union[str, None] = None) -> Dict[str, Any]: - if top_key: - new_dict = {**top[top_key]} if top_key in top else {} - else: - new_dict = {**top} - for key, value in bottom.items(): - if not key in new_dict: - new_dict[key] = value - return new_dict - - -class PromptyStream(Iterator): - """PromptyStream class to iterate over LLM stream. 
- Necessary for Prompty to handle streaming data when tracing.""" - - def __init__(self, name: str, iterator: Iterator): - self.name = name - self.iterator = iterator - self.items: List[Any] = [] - self.__name__ = "PromptyStream" - - def __iter__(self): - return self - - def __next__(self): - try: - # enumerate but add to list - o = self.iterator.__next__() - self.items.append(o) - return o - - except StopIteration: - # StopIteration is raised - # contents are exhausted - if len(self.items) > 0: - with Tracer.start("PromptyStream") as trace: - trace("signature", f"{self.name}.PromptyStream") - trace("inputs", "None") - trace("result", [to_dict(s) for s in self.items]) - - raise StopIteration - - -class AsyncPromptyStream(AsyncIterator): - """AsyncPromptyStream class to iterate over LLM stream. - Necessary for Prompty to handle streaming data when tracing.""" - - def __init__(self, name: str, iterator: AsyncIterator): - self.name = name - self.iterator = iterator - self.items: List[Any] = [] - self.__name__ = "AsyncPromptyStream" - - def __aiter__(self): - return self - - async def __anext__(self): - try: - # enumerate but add to list - o = await self.iterator.__anext__() - self.items.append(o) - return o - - except StopAsyncIteration: - # StopIteration is raised - # contents are exhausted - if len(self.items) > 0: - with Tracer.start("AsyncPromptyStream") as trace: - trace("signature", f"{self.name}.AsyncPromptyStream") - trace("inputs", "None") - trace("result", [to_dict(s) for s in self.items]) - - raise StopAsyncIteration - - -def _mask_secrets(d: Dict[str, Any], path: list[str], patterns: list[str] = ["key", "secret"]) -> bool: - sub_d = d - for key in path: - if key not in sub_d: - return False - sub_d = sub_d[key] - - for k, v in sub_d.items(): - if any([pattern in k.lower() for pattern in patterns]): - sub_d[k] = "*" * len(v) - return True diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_invoker.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_invoker.py deleted file mode 100644 index d682662e7b01..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_invoker.py +++ /dev/null @@ -1,295 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-# ------------------------------------ -# mypy: disable-error-code="return-value,operator" -# pylint: disable=line-too-long,R,docstring-missing-param,docstring-missing-return,docstring-missing-rtype,unnecessary-pass -import abc -from typing import Any, Callable, Dict, Literal -from ._tracer import trace -from ._core import Prompty - - -class Invoker(abc.ABC): - """Abstract class for Invoker - - Attributes - ---------- - prompty : Prompty - The prompty object - name : str - The name of the invoker - - """ - - def __init__(self, prompty: Prompty) -> None: - self.prompty = prompty - self.name = self.__class__.__name__ - - @abc.abstractmethod - def invoke(self, data: Any) -> Any: - """Abstract method to invoke the invoker - - Parameters - ---------- - data : Any - The data to be invoked - - Returns - ------- - Any - The invoked - """ - pass - - @abc.abstractmethod - async def invoke_async(self, data: Any) -> Any: - """Abstract method to invoke the invoker asynchronously - - Parameters - ---------- - data : Any - The data to be invoked - - Returns - ------- - Any - The invoked - """ - pass - - @trace - def run(self, data: Any) -> Any: - """Method to run the invoker - - Parameters - ---------- - data : Any - The data to be invoked - - Returns - ------- - Any - The invoked - """ - return self.invoke(data) - - @trace - async def run_async(self, data: Any) -> Any: - """Method to run the invoker asynchronously - - Parameters - ---------- - data : Any - The data to be invoked - - Returns - ------- - Any - The invoked - """ - return await self.invoke_async(data) - - -class InvokerFactory: - """Factory class for Invoker""" - - _renderers: Dict[str, Invoker] = {} - _parsers: Dict[str, Invoker] = {} - _executors: Dict[str, Invoker] = {} - _processors: Dict[str, Invoker] = {} - - @classmethod - def add_renderer(cls, name: str, invoker: Invoker) -> None: - cls._renderers[name] = invoker - - @classmethod - def add_parser(cls, name: str, invoker: Invoker) -> None: - cls._parsers[name] = invoker - - @classmethod - def add_executor(cls, name: str, invoker: Invoker) -> None: - cls._executors[name] = invoker - - @classmethod - def add_processor(cls, name: str, invoker: Invoker) -> None: - cls._processors[name] = invoker - - @classmethod - def register_renderer(cls, name: str) -> Callable: - def inner_wrapper(wrapped_class: Invoker) -> Callable: - cls._renderers[name] = wrapped_class - return wrapped_class # type: ignore - - return inner_wrapper - - @classmethod - def register_parser(cls, name: str) -> Callable: - def inner_wrapper(wrapped_class: Invoker) -> Callable: - cls._parsers[name] = wrapped_class - return wrapped_class # type: ignore - - return inner_wrapper - - @classmethod - def register_executor(cls, name: str) -> Callable: - def inner_wrapper(wrapped_class: Invoker) -> Callable: - cls._executors[name] = wrapped_class - return wrapped_class # type: ignore - - return inner_wrapper - - @classmethod - def register_processor(cls, name: str) -> Callable: - def inner_wrapper(wrapped_class: Invoker) -> Callable: - cls._processors[name] = wrapped_class - return wrapped_class # type: ignore - - return inner_wrapper - - @classmethod - def _get_name( - cls, - type: Literal["renderer", "parser", "executor", "processor"], - prompty: Prompty, - ) -> str: - if type == "renderer": - return prompty.template.type - elif type == "parser": - return f"{prompty.template.parser}.{prompty.model.api}" - elif type == "executor": - return prompty.model.configuration["type"] - elif type == "processor": - return 
prompty.model.configuration["type"] - else: - raise ValueError(f"Type {type} not found") - - @classmethod - def _get_invoker( - cls, - type: Literal["renderer", "parser", "executor", "processor"], - prompty: Prompty, - ) -> Invoker: - if type == "renderer": - name = prompty.template.type - if name not in cls._renderers: - raise ValueError(f"Renderer {name} not found") - - return cls._renderers[name](prompty) # type: ignore - - elif type == "parser": - name = f"{prompty.template.parser}.{prompty.model.api}" - if name not in cls._parsers: - raise ValueError(f"Parser {name} not found") - - return cls._parsers[name](prompty) # type: ignore - - elif type == "executor": - name = prompty.model.configuration["type"] - if name not in cls._executors: - raise ValueError(f"Executor {name} not found") - - return cls._executors[name](prompty) # type: ignore - - elif type == "processor": - name = prompty.model.configuration["type"] - if name not in cls._processors: - raise ValueError(f"Processor {name} not found") - - return cls._processors[name](prompty) # type: ignore - - else: - raise ValueError(f"Type {type} not found") - - @classmethod - def run( - cls, - type: Literal["renderer", "parser", "executor", "processor"], - prompty: Prompty, - data: Any, - default: Any = None, - ): - name = cls._get_name(type, prompty) - if name.startswith("NOOP") and default is not None: - return default - elif name.startswith("NOOP"): - return data - - invoker = cls._get_invoker(type, prompty) - value = invoker.run(data) - return value - - @classmethod - async def run_async( - cls, - type: Literal["renderer", "parser", "executor", "processor"], - prompty: Prompty, - data: Any, - default: Any = None, - ): - name = cls._get_name(type, prompty) - if name.startswith("NOOP") and default is not None: - return default - elif name.startswith("NOOP"): - return data - invoker = cls._get_invoker(type, prompty) - value = await invoker.run_async(data) - return value - - @classmethod - def run_renderer(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return cls.run("renderer", prompty, data, default) - - @classmethod - async def run_renderer_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return await cls.run_async("renderer", prompty, data, default) - - @classmethod - def run_parser(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return cls.run("parser", prompty, data, default) - - @classmethod - async def run_parser_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return await cls.run_async("parser", prompty, data, default) - - @classmethod - def run_executor(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return cls.run("executor", prompty, data, default) - - @classmethod - async def run_executor_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return await cls.run_async("executor", prompty, data, default) - - @classmethod - def run_processor(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return cls.run("processor", prompty, data, default) - - @classmethod - async def run_processor_async(cls, prompty: Prompty, data: Any, default: Any = None) -> Any: - return await cls.run_async("processor", prompty, data, default) - - -class InvokerException(Exception): - """Exception class for Invoker""" - - def __init__(self, message: str, type: str) -> None: - super().__init__(message) - self.type = type - - def __str__(self) -> str: - return f"{super().__str__()}. 
Make sure to pip install any necessary package extras (i.e. could be something like `pip install prompty[{self.type}]`) for {self.type} as well as import the appropriate invokers (i.e. could be something like `import prompty.{self.type}`)." - - -@InvokerFactory.register_renderer("NOOP") -@InvokerFactory.register_parser("NOOP") -@InvokerFactory.register_executor("NOOP") -@InvokerFactory.register_processor("NOOP") -@InvokerFactory.register_parser("prompty.embedding") -@InvokerFactory.register_parser("prompty.image") -@InvokerFactory.register_parser("prompty.completion") -class NoOp(Invoker): - def invoke(self, data: Any) -> Any: - return data - - async def invoke_async(self, data: str) -> Any: - return self.invoke(data) diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_mustache.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_mustache.py deleted file mode 100644 index f7a0c21d8bb8..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_mustache.py +++ /dev/null @@ -1,671 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -# pylint: disable=line-too-long,R,consider-using-dict-items,docstring-missing-return,docstring-missing-rtype,docstring-missing-param,global-statement,unused-argument,global-variable-not-assigned,protected-access,logging-fstring-interpolation,deprecated-method -from __future__ import annotations -import logging -from collections.abc import Iterator, Sequence -from types import MappingProxyType -from typing import ( - Any, - Dict, - List, - Literal, - Mapping, - Optional, - Union, - cast, -) -from typing_extensions import TypeAlias - -logger = logging.getLogger(__name__) - - -Scopes: TypeAlias = List[Union[Literal[False, 0], Mapping[str, Any]]] - - -# Globals -_CURRENT_LINE = 1 -_LAST_TAG_LINE = None - - -class ChevronError(SyntaxError): - """Custom exception for Chevron errors.""" - - -# -# Helper functions -# - - -def grab_literal(template: str, l_del: str) -> tuple[str, str]: - """Parse a literal from the template. - - Args: - template: The template to parse. - l_del: The left delimiter. - - Returns: - Tuple[str, str]: The literal and the template. - """ - - global _CURRENT_LINE - - try: - # Look for the next tag and move the template to it - literal, template = template.split(l_del, 1) - _CURRENT_LINE += literal.count("\n") - return (literal, template) - - # There are no more tags in the template? - except ValueError: - # Then the rest of the template is a literal - return (template, "") - - -def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: - """Do a preliminary check to see if a tag could be a standalone. - - Args: - template: The template. (Not used.) - literal: The literal. - is_standalone: Whether the tag is standalone. - - Returns: - bool: Whether the tag could be a standalone. - """ - - # If there is a newline, or the previous tag was a standalone - if literal.find("\n") != -1 or is_standalone: - padding = literal.split("\n")[-1] - - # If all the characters since the last newline are spaces - # Then the next tag could be a standalone - # Otherwise it can't be - return padding.isspace() or padding == "" - else: - return False - - -def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: - """Do a final check to see if a tag could be a standalone. - - Args: - template: The template. - tag_type: The type of the tag. - is_standalone: Whether the tag is standalone. 
- - Returns: - bool: Whether the tag could be a standalone. - """ - - # Check right side if we might be a standalone - if is_standalone and tag_type not in ["variable", "no escape"]: - on_newline = template.split("\n", 1) - - # If the stuff to the right of us are spaces we're a standalone - return on_newline[0].isspace() or not on_newline[0] - - # If we're a tag can't be a standalone - else: - return False - - -def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]: - """Parse a tag from a template. - - Args: - template: The template. - l_del: The left delimiter. - r_del: The right delimiter. - - Returns: - Tuple[Tuple[str, str], str]: The tag and the template. - - Raises: - ChevronError: If the tag is unclosed. - ChevronError: If the set delimiter tag is unclosed. - """ - global _CURRENT_LINE - global _LAST_TAG_LINE - - tag_types = { - "!": "comment", - "#": "section", - "^": "inverted section", - "/": "end", - ">": "partial", - "=": "set delimiter?", - "{": "no escape?", - "&": "no escape", - } - - # Get the tag - try: - tag, template = template.split(r_del, 1) - except ValueError as e: - msg = "unclosed tag " f"at line {_CURRENT_LINE}" - raise ChevronError(msg) from e - - # Find the type meaning of the first character - tag_type = tag_types.get(tag[0], "variable") - - # If the type is not a variable - if tag_type != "variable": - # Then that first character is not needed - tag = tag[1:] - - # If we might be a set delimiter tag - if tag_type == "set delimiter?": - # Double check to make sure we are - if tag.endswith("="): - tag_type = "set delimiter" - # Remove the equal sign - tag = tag[:-1] - - # Otherwise we should complain - else: - msg = "unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}" - raise ChevronError(msg) - - elif ( - # If we might be a no html escape tag - tag_type == "no escape?" - # And we have a third curly brace - # (And are using curly braces as delimiters) - and l_del == "{{" - and r_del == "}}" - and template.startswith("}") - ): - # Then we are a no html escape tag - template = template[1:] - tag_type = "no escape" - - # Strip the whitespace off the key and return - return ((tag_type, tag.strip()), template) - - -# -# The main tokenizing function -# - - -def tokenize(template: str, def_ldel: str = "{{", def_rdel: str = "}}") -> Iterator[tuple[str, str]]: - """Tokenize a mustache template. - - Tokenizes a mustache template in a generator fashion, - using file-like objects. It also accepts a string containing - the template. - - - Arguments: - - template -- a file-like object, or a string of a mustache template - - def_ldel -- The default left delimiter - ("{{" by default, as in spec compliant mustache) - - def_rdel -- The default right delimiter - ("}}" by default, as in spec compliant mustache) - - - Returns: - - A generator of mustache tags in the form of a tuple - - -- (tag_type, tag_key) - - Where tag_type is one of: - * literal - * section - * inverted section - * end - * partial - * no escape - - And tag_key is either the key or in the case of a literal tag, - the literal itself. 
- """ - - global _CURRENT_LINE, _LAST_TAG_LINE - _CURRENT_LINE = 1 - _LAST_TAG_LINE = None - - is_standalone = True - open_sections = [] - l_del = def_ldel - r_del = def_rdel - - while template: - literal, template = grab_literal(template, l_del) - - # If the template is completed - if not template: - # Then yield the literal and leave - yield ("literal", literal) - break - - # Do the first check to see if we could be a standalone - is_standalone = l_sa_check(template, literal, is_standalone) - - # Parse the tag - tag, template = parse_tag(template, l_del, r_del) - tag_type, tag_key = tag - - # Special tag logic - - # If we are a set delimiter tag - if tag_type == "set delimiter": - # Then get and set the delimiters - dels = tag_key.strip().split(" ") - l_del, r_del = dels[0], dels[-1] - - # If we are a section tag - elif tag_type in ["section", "inverted section"]: - # Then open a new section - open_sections.append(tag_key) - _LAST_TAG_LINE = _CURRENT_LINE - - # If we are an end tag - elif tag_type == "end": - # Then check to see if the last opened section - # is the same as us - try: - last_section = open_sections.pop() - except IndexError as e: - msg = f'Trying to close tag "{tag_key}"\n' "Looks like it was not opened.\n" f"line {_CURRENT_LINE + 1}" - raise ChevronError(msg) from e - if tag_key != last_section: - # Otherwise we need to complain - msg = ( - f'Trying to close tag "{tag_key}"\n' - f'last open tag is "{last_section}"\n' - f"line {_CURRENT_LINE + 1}" - ) - raise ChevronError(msg) - - # Do the second check to see if we're a standalone - is_standalone = r_sa_check(template, tag_type, is_standalone) - - # Which if we are - if is_standalone: - # Remove the stuff before the newline - template = template.split("\n", 1)[-1] - - # Partials need to keep the spaces on their left - if tag_type != "partial": - # But other tags don't - literal = literal.rstrip(" ") - - # Start yielding - # Ignore literals that are empty - if literal != "": - yield ("literal", literal) - - # Ignore comments and set delimiters - if tag_type not in ["comment", "set delimiter?"]: - yield (tag_type, tag_key) - - # If there are any open sections when we're done - if open_sections: - # Then we need to complain - msg = ( - "Unexpected EOF\n" - f'the tag "{open_sections[-1]}" was never closed\n' - f"was opened at line {_LAST_TAG_LINE}" - ) - raise ChevronError(msg) - - -# -# Helper functions -# - - -def _html_escape(string: str) -> str: - """HTML escape all of these " & < >""" - - html_codes = { - '"': """, - "<": "<", - ">": ">", - } - - # & must be handled first - string = string.replace("&", "&") - for char in html_codes: - string = string.replace(char, html_codes[char]) - return string - - -def _get_key( - key: str, - scopes: Scopes, - warn: bool, - keep: bool, - def_ldel: str, - def_rdel: str, -) -> Any: - """Get a key from the current scope""" - - # If the key is a dot - if key == ".": - # Then just return the current scope - return scopes[0] - - # Loop through the scopes - for scope in scopes: - try: - # Return an empty string if falsy, with two exceptions - # 0 should return 0, and False should return False - if scope in (0, False): - return scope - - # For every dot separated key - for child in key.split("."): - # Return an empty string if falsy, with two exceptions - # 0 should return 0, and False should return False - if scope in (0, False): - return scope - # Move into the scope - try: - # Try subscripting (Normal dictionaries) - scope = cast(Dict[str, Any], scope)[child] - except (TypeError, 
AttributeError): - try: - scope = getattr(scope, child) - except (TypeError, AttributeError): - # Try as a list - scope = scope[int(child)] # type: ignore - - try: - # This allows for custom falsy data types - # https://github.com/noahmorrison/chevron/issues/35 - if scope._CHEVRON_return_scope_when_falsy: # type: ignore - return scope - except AttributeError: - if scope in (0, False): - return scope - return scope or "" - except (AttributeError, KeyError, IndexError, ValueError): - # We couldn't find the key in the current scope - # We'll try again on the next pass - pass - - # We couldn't find the key in any of the scopes - - if warn: - logger.warn(f"Could not find key '{key}'") - - if keep: - return f"{def_ldel} {key} {def_rdel}" - - return "" - - -def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str: - """Load a partial""" - try: - # Maybe the partial is in the dictionary - return partials_dict[name] - except KeyError: - return "" - - -# -# The main rendering function -# -g_token_cache: Dict[str, List[tuple[str, str]]] = {} - -EMPTY_DICT: MappingProxyType[str, str] = MappingProxyType({}) - - -def render( - template: Union[str, List[tuple[str, str]]] = "", - data: Mapping[str, Any] = EMPTY_DICT, - partials_dict: Mapping[str, str] = EMPTY_DICT, - padding: str = "", - def_ldel: str = "{{", - def_rdel: str = "}}", - scopes: Optional[Scopes] = None, - warn: bool = False, - keep: bool = False, -) -> str: - """Render a mustache template. - - Renders a mustache template with a data scope and inline partial capability. - - Arguments: - - template -- A file-like object or a string containing the template. - - data -- A python dictionary with your data scope. - - partials_path -- The path to where your partials are stored. - If set to None, then partials won't be loaded from the file system - (defaults to '.'). - - partials_ext -- The extension that you want the parser to look for - (defaults to 'mustache'). - - partials_dict -- A python dictionary which will be search for partials - before the filesystem is. {'include': 'foo'} is the same - as a file called include.mustache - (defaults to {}). - - padding -- This is for padding partials, and shouldn't be used - (but can be if you really want to). - - def_ldel -- The default left delimiter - ("{{" by default, as in spec compliant mustache). - - def_rdel -- The default right delimiter - ("}}" by default, as in spec compliant mustache). - - scopes -- The list of scopes that get_key will look through. - - warn -- Log a warning when a template substitution isn't found in the data - - keep -- Keep unreplaced tags when a substitution isn't found in the data. - - - Returns: - - A string containing the rendered template. 
- """ - - # If the template is a sequence but not derived from a string - if isinstance(template, Sequence) and not isinstance(template, str): - # Then we don't need to tokenize it - # But it does need to be a generator - tokens: Iterator[tuple[str, str]] = (token for token in template) - else: - if template in g_token_cache: - tokens = (token for token in g_token_cache[template]) - else: - # Otherwise make a generator - tokens = tokenize(template, def_ldel, def_rdel) - - output = "" - - if scopes is None: - scopes = [data] - - # Run through the tokens - for tag, key in tokens: - # Set the current scope - current_scope = scopes[0] - - # If we're an end tag - if tag == "end": - # Pop out of the latest scope - del scopes[0] - - # If the current scope is falsy and not the only scope - elif not current_scope and len(scopes) != 1: - if tag in ["section", "inverted section"]: - # Set the most recent scope to a falsy value - scopes.insert(0, False) - - # If we're a literal tag - elif tag == "literal": - # Add padding to the key and add it to the output - output += key.replace("\n", "\n" + padding) - - # If we're a variable tag - elif tag == "variable": - # Add the html escaped key to the output - thing = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) - if thing is True and key == ".": - # if we've coerced into a boolean by accident - # (inverted tags do this) - # then get the un-coerced object (next in the stack) - thing = scopes[1] - if not isinstance(thing, str): - thing = str(thing) - output += _html_escape(thing) - - # If we're a no html escape tag - elif tag == "no escape": - # Just lookup the key and add it - thing = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) - if not isinstance(thing, str): - thing = str(thing) - output += thing - - # If we're a section tag - elif tag == "section": - # Get the sections scope - scope = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) - - # If the scope is a callable (as described in - # https://mustache.github.io/mustache.5.html) - if callable(scope): - # Generate template text from tags - text = "" - tags: List[tuple[str, str]] = [] - for token in tokens: - if token == ("end", key): - break - - tags.append(token) - tag_type, tag_key = token - if tag_type == "literal": - text += tag_key - elif tag_type == "no escape": - text += f"{def_ldel}& {tag_key} {def_rdel}" - else: - text += "{}{} {}{}".format( - def_ldel, - { - "comment": "!", - "section": "#", - "inverted section": "^", - "end": "/", - "partial": ">", - "set delimiter": "=", - "no escape": "&", - "variable": "", - }[tag_type], - tag_key, - def_rdel, - ) - - g_token_cache[text] = tags - - rend = scope( - text, - lambda template, data=None: render( - template, - data={}, - partials_dict=partials_dict, - padding=padding, - def_ldel=def_ldel, - def_rdel=def_rdel, - scopes=data and [data] + scopes or scopes, - warn=warn, - keep=keep, - ), - ) - - output += rend # type: ignore[reportOperatorIssue] - - # If the scope is a sequence, an iterator or generator but not - # derived from a string - elif isinstance(scope, (Sequence, Iterator)) and not isinstance(scope, str): - # Then we need to do some looping - - # Gather up all the tags inside the section - # (And don't be tricked by nested end tags with the same key) - # TODO: This feels like it still has edge cases, no? 
- tags = [] - tags_with_same_key = 0 - for token in tokens: - if token == ("section", key): - tags_with_same_key += 1 - if token == ("end", key): - tags_with_same_key -= 1 - if tags_with_same_key < 0: - break - tags.append(token) - - # For every item in the scope - for thing in scope: - # Append it as the most recent scope and render - new_scope = [thing] + scopes - rend = render( - template=tags, - scopes=new_scope, - padding=padding, - partials_dict=partials_dict, - def_ldel=def_ldel, - def_rdel=def_rdel, - warn=warn, - keep=keep, - ) - - output += rend - - else: - # Otherwise we're just a scope section - scopes.insert(0, scope) # type: ignore[reportArgumentType] - - # If we're an inverted section - elif tag == "inverted section": - # Add the flipped scope to the scopes - scope = _get_key(key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel) - scopes.insert(0, cast(Literal[False], not scope)) - - # If we're a partial - elif tag == "partial": - # Load the partial - partial = _get_partial(key, partials_dict) - - # Find what to pad the partial with - left = output.rpartition("\n")[2] - part_padding = padding - if left.isspace(): - part_padding += left - - # Render the partial - part_out = render( - template=partial, - partials_dict=partials_dict, - def_ldel=def_ldel, - def_rdel=def_rdel, - padding=part_padding, - scopes=scopes, - warn=warn, - keep=keep, - ) - - # If the partial was indented - if left.isspace(): - # then remove the spaces from the end - part_out = part_out.rstrip(" \t") - - # Add the partials output to the output - output += part_out - - return output diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_parsers.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_parsers.py deleted file mode 100644 index de3c570e5c89..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_parsers.py +++ /dev/null @@ -1,156 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-# ------------------------------------ -# mypy: disable-error-code="union-attr,return-value" -# pylint: disable=line-too-long,R,consider-using-enumerate,docstring-missing-param,docstring-missing-return,docstring-missing-rtype -import re -import base64 -from pathlib import Path -from typing import Any, Union -from ._core import Prompty -from ._invoker import Invoker, InvokerFactory - - -ROLES = ["assistant", "function", "system", "user"] - - -@InvokerFactory.register_parser("prompty.chat") -class PromptyChatParser(Invoker): - """Prompty Chat Parser""" - - def __init__(self, prompty: Prompty) -> None: - super().__init__(prompty) - self.path = Path(self.prompty.file).parent - - def invoke(self, data: str) -> Any: - return invoke_parser(self.path, data) - - async def invoke_async(self, data: str) -> Any: - """Invoke the Prompty Chat Parser (Async) - - Parameters - ---------- - data : str - The data to parse - - Returns - ------- - str - The parsed data - """ - return self.invoke(data) - - -def _inline_image(path: Union[Path, None], image_item: str) -> str: - """Inline Image - - Parameters - ---------- - image_item : str - The image item to inline - - Returns - ------- - str - The inlined image - """ - # pass through if it's a url or base64 encoded or the path is None - if image_item.startswith("http") or image_item.startswith("data") or path is None: - return image_item - # otherwise, it's a local file - need to base64 encode it - else: - image_path = (path if path is not None else Path(".")) / image_item - with open(image_path, "rb") as f: - base64_image = base64.b64encode(f.read()).decode("utf-8") - - if image_path.suffix == ".png": - return f"data:image/png;base64,{base64_image}" - elif image_path.suffix == ".jpg": - return f"data:image/jpeg;base64,{base64_image}" - elif image_path.suffix == ".jpeg": - return f"data:image/jpeg;base64,{base64_image}" - else: - raise ValueError( - f"Invalid image format {image_path.suffix} - currently only .png and .jpg / .jpeg are supported." 
- ) - - -def _parse_content(path: Union[Path, None], content: str): - """for parsing inline images - - Parameters - ---------- - content : str - The content to parse - - Returns - ------- - any - The parsed content - """ - # regular expression to parse markdown images - image = r"(?P<alt>!\[[^\]]*\])\((?P<filename>.*?)(?=\"|\))\)" - matches = re.findall(image, content, flags=re.MULTILINE) - if len(matches) > 0: - content_items = [] - content_chunks = re.split(image, content, flags=re.MULTILINE) - current_chunk = 0 - for i in range(len(content_chunks)): - # image entry - if current_chunk < len(matches) and content_chunks[i] == matches[current_chunk][0]: - content_items.append( - { - "type": "image_url", - "image_url": {"url": _inline_image(path, matches[current_chunk][1].split(" ")[0].strip())}, - } - ) - # second part of image entry - elif current_chunk < len(matches) and content_chunks[i] == matches[current_chunk][1]: - current_chunk += 1 - # text entry - else: - if len(content_chunks[i].strip()) > 0: - content_items.append({"type": "text", "text": content_chunks[i].strip()}) - return content_items - else: - return content - - -def invoke_parser(path: Union[Path, None], data: str) -> Any: - """Invoke the Prompty Chat Parser - - Parameters - ---------- - data : str - The data to parse - - Returns - ------- - str - The parsed data - """ - messages = [] - separator = r"(?i)^\s*#?\s*(" + "|".join(ROLES) + r")\s*:\s*\n" - - # get valid chunks - remove empty items - chunks = [item for item in re.split(separator, data, flags=re.MULTILINE) if len(item.strip()) > 0] - - # if no starter role, then inject system role - if not chunks[0].strip().lower() in ROLES: - chunks.insert(0, "system") - - # if last chunk is role entry, then remove (no content?) - if chunks[-1].strip().lower() in ROLES: - chunks.pop() - - if len(chunks) % 2 != 0: - raise ValueError("Invalid prompt format") - - # create messages - for i in range(0, len(chunks), 2): - role = chunks[i].strip().lower() - content = chunks[i + 1].strip() - messages.append({"role": role, "content": _parse_content(path, content)}) - - return messages diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_patch.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_patch.py deleted file mode 100644 index 14ad4f62b4c1..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_patch.py +++ /dev/null @@ -1,124 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -# pylint: disable=line-too-long,R -"""Customize generated code here. - -Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize -""" - -import traceback -from pathlib import Path -from typing import Any, Dict, List, Optional -from typing_extensions import Self -from ._core import Prompty -from ._mustache import render -from ._parsers import invoke_parser -from ._prompty_utils import load, prepare -from ._utils import remove_leading_empty_space - - -class PromptTemplate: - """The helper class which takes variant of inputs, e.g. Prompty format or string, and returns the parsed prompt in an array.""" - - @classmethod - def from_prompty(cls, file_path: str) -> Self: - """Initialize a PromptTemplate object from a prompty file. - - :param file_path: The path to the prompty file. - :type file_path: str - :return: The PromptTemplate object.
- :rtype: PromptTemplate - """ - if not file_path: - raise ValueError("Please provide file_path") - - # Get the absolute path of the file by `traceback.extract_stack()`, it's "-2" because: - # In the stack, the last function is the current function. - # The second last function is the caller function, which is the root of the file_path. - stack = traceback.extract_stack() - caller = Path(stack[-2].filename) - abs_file_path = Path(caller.parent / Path(file_path)).resolve().absolute() - - prompty = load(str(abs_file_path)) - return cls(prompty=prompty) - - @classmethod - def from_string(cls, prompt_template: str, api: str = "chat", model_name: Optional[str] = None) -> Self: - """Initialize a PromptTemplate object from a message template. - - :param prompt_template: The prompt template string. - :type prompt_template: str - :param api: The API type, e.g. "chat" or "completion". - :type api: str - :param model_name: The model name, e.g. "gpt-4o-mini". - :type model_name: str - :return: The PromptTemplate object. - :rtype: PromptTemplate - """ - return cls( - api=api, - prompt_template=prompt_template, - model_name=model_name, - prompty=None, - ) - - def __init__( - self, - *, - api: str = "chat", - prompty: Optional[Prompty] = None, - prompt_template: Optional[str] = None, - model_name: Optional[str] = None, - ) -> None: - self.prompty = prompty - if self.prompty is not None: - self.model_name = ( - self.prompty.model.configuration["azure_deployment"] - if "azure_deployment" in self.prompty.model.configuration - else None - ) - self.parameters = self.prompty.model.parameters - self._config = {} - elif prompt_template is not None: - self.model_name = model_name - self.parameters = {} - # _config is a dict to hold the internal configuration - self._config = { - "api": api if api is not None else "chat", - "prompt_template": prompt_template, - } - else: - raise ValueError("Please pass valid arguments for PromptTemplate") - - def create_messages(self, data: Optional[Dict[str, Any]] = None, **kwargs) -> List[Dict[str, Any]]: - """Render the prompt template with the given data. - - :param data: The data to render the prompt template with. - :type data: Optional[Dict[str, Any]] - :return: The rendered prompt template. - :rtype: List[Dict[str, Any]] - """ - if data is None: - data = kwargs - - if self.prompty is not None: - parsed = prepare(self.prompty, data) - return parsed - elif "prompt_template" in self._config: - prompt_template = remove_leading_empty_space(self._config["prompt_template"]) - system_prompt_str = render(prompt_template, data) - parsed = invoke_parser(None, system_prompt_str) - return parsed - else: - raise ValueError("Please provide valid prompt template") - - -def patch_sdk(): - """Do not remove from this file. - - `patch_sdk` is a last resort escape hatch that allows you to do customizations - you can't accomplish using the techniques described in - https://aka.ms/azsdk/python/dpcodegen/python/customize - """ diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_prompty_utils.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_prompty_utils.py deleted file mode 100644 index 5ea38bda6229..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_prompty_utils.py +++ /dev/null @@ -1,415 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-# ------------------------------------ -# mypy: disable-error-code="assignment" -# pylint: disable=R,docstring-missing-param,docstring-missing-return,docstring-missing-rtype,dangerous-default-value,redefined-outer-name,unused-wildcard-import,wildcard-import,raise-missing-from -import traceback -from pathlib import Path -from typing import Any, Dict, List, Union -from ._tracer import trace -from ._invoker import InvokerFactory -from ._core import ( - ModelSettings, - Prompty, - PropertySettings, - TemplateSettings, - param_hoisting, -) -from ._utils import ( - load_global_config, - load_prompty, -) - -from ._renderers import * -from ._parsers import * - - -@trace(description="Create a headless prompty object for programmatic use.") -def headless( - api: str, - content: Union[str, List[str], dict], - configuration: Dict[str, Any] = {}, - parameters: Dict[str, Any] = {}, - connection: str = "default", -) -> Prompty: - """Create a headless prompty object for programmatic use. - - Parameters - ---------- - api : str - The API to use for the model - content : Union[str, List[str], dict] - The content to process - configuration : Dict[str, Any], optional - The configuration to use, by default {} - parameters : Dict[str, Any], optional - The parameters to use, by default {} - connection : str, optional - The connection to use, by default "default" - - Returns - ------- - Prompty - The headless prompty object - - Example - ------- - >>> import prompty - >>> p = prompty.headless( - api="embedding", - configuration={"type": "azure", "azure_deployment": "text-embedding-ada-002"}, - content="hello world", - ) - >>> emb = prompty.execute(p) - - """ - - # get caller's path (to get relative path for prompty.json) - caller = Path(traceback.extract_stack()[-2].filename) - templateSettings = TemplateSettings(type="NOOP", parser="NOOP") - modelSettings = ModelSettings( - api=api, - configuration=Prompty.normalize( - param_hoisting(configuration, load_global_config(caller.parent, connection)), - caller.parent, - ), - parameters=parameters, - ) - - return Prompty(model=modelSettings, template=templateSettings, content=content) - - -def _load_raw_prompty(attributes: dict, content: str, p: Path, global_config: dict): - if "model" not in attributes: - attributes["model"] = {} - - if "configuration" not in attributes["model"]: - attributes["model"]["configuration"] = global_config - else: - attributes["model"]["configuration"] = param_hoisting( - attributes["model"]["configuration"], - global_config, - ) - - # pull model settings out of attributes - try: - model = ModelSettings(**attributes.pop("model")) - except Exception as e: - raise ValueError(f"Error in model settings: {e}") - - # pull template settings - try: - if "template" in attributes: - t = attributes.pop("template") - if isinstance(t, dict): - template = TemplateSettings(**t) - # has to be a string denoting the type - else: - template = TemplateSettings(type=t, parser="prompty") - else: - template = TemplateSettings(type="mustache", parser="prompty") - except Exception as e: - raise ValueError(f"Error in template loader: {e}") - - # formalize inputs and outputs - if "inputs" in attributes: - try: - inputs = {k: PropertySettings(**v) for (k, v) in attributes.pop("inputs").items()} - except Exception as e: - raise ValueError(f"Error in inputs: {e}") - else: - inputs = {} - if "outputs" in attributes: - try: - outputs = {k: PropertySettings(**v) for (k, v) in attributes.pop("outputs").items()} - except Exception as e: - raise ValueError(f"Error in 
outputs: {e}") - else: - outputs = {} - - prompty = Prompty( - **attributes, - model=model, - inputs=inputs, - outputs=outputs, - template=template, - content=content, - file=p, - ) - - return prompty - - -@trace(description="Load a prompty file.") -def load(prompty_file: Union[str, Path], configuration: str = "default") -> Prompty: - """Load a prompty file. - - Parameters - ---------- - prompty_file : Union[str, Path] - The path to the prompty file - configuration : str, optional - The configuration to use, by default "default" - - Returns - ------- - Prompty - The loaded prompty object - - Example - ------- - >>> import prompty - >>> p = prompty.load("prompts/basic.prompty") - >>> print(p) - """ - - p = Path(prompty_file) - if not p.is_absolute(): - # get caller's path (take into account trace frame) - caller = Path(traceback.extract_stack()[-3].filename) - p = Path(caller.parent / p).resolve().absolute() - - # load dictionary from prompty file - matter = load_prompty(p) - - attributes = matter["attributes"] - content = matter["body"] - - # normalize attribute dictionary resolve keys and files - attributes = Prompty.normalize(attributes, p.parent) - - # load global configuration - global_config = Prompty.normalize(load_global_config(p.parent, configuration), p.parent) - - prompty = _load_raw_prompty(attributes, content, p, global_config) - - # recursive loading of base prompty - if "base" in attributes: - # load the base prompty from the same directory as the current prompty - base = load(p.parent / attributes["base"]) - prompty = Prompty.hoist_base_prompty(prompty, base) - - return prompty - - -@trace(description="Prepare the inputs for the prompt.") -def prepare( - prompt: Prompty, - inputs: Dict[str, Any] = {}, -): - """Prepare the inputs for the prompt. - - Parameters - ---------- - prompt : Prompty - The prompty object - inputs : Dict[str, Any], optional - The inputs to the prompt, by default {} - - Returns - ------- - dict - The prepared and hidrated template shaped to the LLM model - - Example - ------- - >>> import prompty - >>> p = prompty.load("prompts/basic.prompty") - >>> inputs = {"name": "John Doe"} - >>> content = prompty.prepare(p, inputs) - """ - inputs = param_hoisting(inputs, prompt.sample) - - render = InvokerFactory.run_renderer(prompt, inputs, prompt.content) - result = InvokerFactory.run_parser(prompt, render) - - return result - - -@trace(description="Prepare the inputs for the prompt.") -async def prepare_async( - prompt: Prompty, - inputs: Dict[str, Any] = {}, -): - """Prepare the inputs for the prompt. - - Parameters - ---------- - prompt : Prompty - The prompty object - inputs : Dict[str, Any], optional - The inputs to the prompt, by default {} - - Returns - ------- - dict - The prepared and hidrated template shaped to the LLM model - - Example - ------- - >>> import prompty - >>> p = prompty.load("prompts/basic.prompty") - >>> inputs = {"name": "John Doe"} - >>> content = await prompty.prepare_async(p, inputs) - """ - inputs = param_hoisting(inputs, prompt.sample) - - render = await InvokerFactory.run_renderer_async(prompt, inputs, prompt.content) - result = await InvokerFactory.run_parser_async(prompt, render) - - return result - - -@trace(description="Run the prepared Prompty content against the model.") -def run( - prompt: Prompty, - content: Union[dict, list, str], - configuration: Dict[str, Any] = {}, - parameters: Dict[str, Any] = {}, - raw: bool = False, -): - """Run the prepared Prompty content. 
- - Parameters - ---------- - prompt : Prompty - The prompty object - content : Union[dict, list, str] - The content to process - configuration : Dict[str, Any], optional - The configuration to use, by default {} - parameters : Dict[str, Any], optional - The parameters to use, by default {} - raw : bool, optional - Whether to skip processing, by default False - - Returns - ------- - Any - The result of the prompt - - Example - ------- - >>> import prompty - >>> p = prompty.load("prompts/basic.prompty") - >>> inputs = {"name": "John Doe"} - >>> content = prompty.prepare(p, inputs) - >>> result = prompty.run(p, content) - """ - - if configuration != {}: - prompt.model.configuration = param_hoisting(configuration, prompt.model.configuration) - - if parameters != {}: - prompt.model.parameters = param_hoisting(parameters, prompt.model.parameters) - - result = InvokerFactory.run_executor(prompt, content) - if not raw: - result = InvokerFactory.run_processor(prompt, result) - - return result - - -@trace(description="Run the prepared Prompty content against the model.") -async def run_async( - prompt: Prompty, - content: Union[dict, list, str], - configuration: Dict[str, Any] = {}, - parameters: Dict[str, Any] = {}, - raw: bool = False, -): - """Run the prepared Prompty content. - - Parameters - ---------- - prompt : Prompty - The prompty object - content : Union[dict, list, str] - The content to process - configuration : Dict[str, Any], optional - The configuration to use, by default {} - parameters : Dict[str, Any], optional - The parameters to use, by default {} - raw : bool, optional - Whether to skip processing, by default False - - Returns - ------- - Any - The result of the prompt - - Example - ------- - >>> import prompty - >>> p = prompty.load("prompts/basic.prompty") - >>> inputs = {"name": "John Doe"} - >>> content = await prompty.prepare_async(p, inputs) - >>> result = await prompty.run_async(p, content) - """ - - if configuration != {}: - prompt.model.configuration = param_hoisting(configuration, prompt.model.configuration) - - if parameters != {}: - prompt.model.parameters = param_hoisting(parameters, prompt.model.parameters) - - result = await InvokerFactory.run_executor_async(prompt, content) - if not raw: - result = await InvokerFactory.run_processor_async(prompt, result) - - return result - - -@trace(description="Execute a prompty") -def execute( - prompt: Union[str, Prompty], - configuration: Dict[str, Any] = {}, - parameters: Dict[str, Any] = {}, - inputs: Dict[str, Any] = {}, - raw: bool = False, - config_name: str = "default", -): - """Execute a prompty. 
- - Parameters - ---------- - prompt : Union[str, Prompty] - The prompty object or path to the prompty file - configuration : Dict[str, Any], optional - The configuration to use, by default {} - parameters : Dict[str, Any], optional - The parameters to use, by default {} - inputs : Dict[str, Any], optional - The inputs to the prompt, by default {} - raw : bool, optional - Whether to skip processing, by default False - connection : str, optional - The connection to use, by default "default" - - Returns - ------- - Any - The result of the prompt - - Example - ------- - >>> import prompty - >>> inputs = {"name": "John Doe"} - >>> result = prompty.execute("prompts/basic.prompty", inputs=inputs) - """ - if isinstance(prompt, str): - path = Path(prompt) - if not path.is_absolute(): - # get caller's path (take into account trace frame) - caller = Path(traceback.extract_stack()[-3].filename) - path = Path(caller.parent / path).resolve().absolute() - prompt = load(path, config_name) - - # prepare content - content = prepare(prompt, inputs) - - # run LLM model - result = run(prompt, content, configuration, parameters, raw) - - return result diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_renderers.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_renderers.py deleted file mode 100644 index 0d682a7fe151..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_renderers.py +++ /dev/null @@ -1,30 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -# mypy: disable-error-code="union-attr,assignment,arg-type" -from pathlib import Path -from ._core import Prompty -from ._invoker import Invoker, InvokerFactory -from ._mustache import render - - -@InvokerFactory.register_renderer("mustache") -class MustacheRenderer(Invoker): - """Render a mustache template.""" - - def __init__(self, prompty: Prompty) -> None: - super().__init__(prompty) - self.templates = {} - cur_prompt = self.prompty - while cur_prompt: - self.templates[Path(cur_prompt.file).name] = cur_prompt.content - cur_prompt = cur_prompt.basePrompty - self.name = Path(self.prompty.file).name - - def invoke(self, data: str) -> str: - generated = render(self.prompty.content, data) # type: ignore - return generated - - async def invoke_async(self, data: str) -> str: - return self.invoke(data) diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_tracer.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_tracer.py deleted file mode 100644 index 24f800b465f4..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_tracer.py +++ /dev/null @@ -1,316 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
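The module-level helpers above compose in the order their docstrings describe: `execute` is `load` followed by `prepare` and `run`. A compressed sketch of the first two steps, assuming a placeholder `prompts/basic.prompty` file (YAML front matter plus a mustache body) and `pyyaml` installed; the import path is the internal module being removed here.

```python
# Sketch of the load -> prepare flow implemented by the deleted helpers above.
# "prompts/basic.prompty" and the inputs are placeholders.
from azure.ai.inference.prompts._prompty_utils import load, prepare

prompty = load("prompts/basic.prompty")            # parse front matter and template body
messages = prepare(prompty, {"name": "John Doe"})  # render mustache, then parse role markers
# execute("prompts/basic.prompty", inputs={"name": "John Doe"}) wraps load -> prepare -> run.
```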
-# ------------------------------------ -# mypy: disable-error-code="union-attr,arg-type,misc,return-value,assignment,func-returns-value" -# pylint: disable=R,redefined-outer-name,bare-except,unspecified-encoding -import os -import json -import inspect -import traceback -import importlib -import contextlib -from pathlib import Path -from numbers import Number -from datetime import datetime -from functools import wraps, partial -from typing import Any, Callable, Dict, Iterator, List, Union - - -# clean up key value pairs for sensitive values -def sanitize(key: str, value: Any) -> Any: - if isinstance(value, str) and any([s in key.lower() for s in ["key", "token", "secret", "password", "credential"]]): - return len(str(value)) * "*" - - if isinstance(value, dict): - return {k: sanitize(k, v) for k, v in value.items()} - - return value - - -class Tracer: - _tracers: Dict[str, Callable[[str], Iterator[Callable[[str, Any], None]]]] = {} - - @classmethod - def add(cls, name: str, tracer: Callable[[str], Iterator[Callable[[str, Any], None]]]) -> None: - cls._tracers[name] = tracer - - @classmethod - def clear(cls) -> None: - cls._tracers = {} - - @classmethod - @contextlib.contextmanager - def start(cls, name: str) -> Iterator[Callable[[str, Any], None]]: - with contextlib.ExitStack() as stack: - traces: List[Any] = [stack.enter_context(tracer(name)) for tracer in cls._tracers.values()] # type: ignore - yield lambda key, value: [ # type: ignore - # normalize and sanitize any trace values - trace(key, sanitize(key, to_dict(value))) - for trace in traces - ] - - -def to_dict(obj: Any) -> Union[Dict[str, Any], List[Dict[str, Any]], str, Number, bool]: - # simple json types - if isinstance(obj, str) or isinstance(obj, Number) or isinstance(obj, bool): - return obj - - # datetime - if isinstance(obj, datetime): - return obj.isoformat() - - # safe Prompty obj serialization - if type(obj).__name__ == "Prompty": - return obj.to_safe_dict() - - # safe PromptyStream obj serialization - if type(obj).__name__ == "PromptyStream": - return "PromptyStream" - - if type(obj).__name__ == "AsyncPromptyStream": - return "AsyncPromptyStream" - - # recursive list and dict - if isinstance(obj, List): - return [to_dict(item) for item in obj] # type: ignore - - if isinstance(obj, Dict): - return {k: v if isinstance(v, str) else to_dict(v) for k, v in obj.items()} - - if isinstance(obj, Path): - return str(obj) - - # cast to string otherwise... 
- return str(obj) - - -def _name(func: Callable, args): - if hasattr(func, "__qualname__"): - signature = f"{func.__module__}.{func.__qualname__}" - else: - signature = f"{func.__module__}.{func.__name__}" - - # core invoker gets special treatment prompty.invoker.Invoker - core_invoker = signature.startswith("prompty.invoker.Invoker.run") - if core_invoker: - name = type(args[0]).__name__ - if signature.endswith("async"): - signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async" - else: - signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke" - else: - name = func.__name__ - - return name, signature - - -def _inputs(func: Callable, args, kwargs) -> dict: - ba = inspect.signature(func).bind(*args, **kwargs) - ba.apply_defaults() - - inputs = {k: to_dict(v) for k, v in ba.arguments.items() if k != "self"} - - return inputs - - -def _results(result: Any) -> Union[Dict, List[Dict], str, Number, bool]: - return to_dict(result) if result is not None else "None" - - -def _trace_sync(func: Union[Callable, None] = None, **okwargs: Any) -> Callable: - - @wraps(func) # type: ignore - def wrapper(*args, **kwargs): - name, signature = _name(func, args) # type: ignore - with Tracer.start(name) as trace: - trace("signature", signature) - - # support arbitrary keyword - # arguments for trace decorator - for k, v in okwargs.items(): - trace(k, to_dict(v)) - - inputs = _inputs(func, args, kwargs) # type: ignore - trace("inputs", inputs) - - try: - result = func(*args, **kwargs) # type: ignore - trace("result", _results(result)) - except Exception as e: - trace( - "result", - { - "exception": { - "type": type(e), - "traceback": (traceback.format_tb(tb=e.__traceback__) if e.__traceback__ else None), - "message": str(e), - "args": to_dict(e.args), - } - }, - ) - raise e - - return result - - return wrapper - - -def _trace_async(func: Union[Callable, None] = None, **okwargs: Any) -> Callable: - - @wraps(func) # type: ignore - async def wrapper(*args, **kwargs): - name, signature = _name(func, args) # type: ignore - with Tracer.start(name) as trace: - trace("signature", signature) - - # support arbitrary keyword - # arguments for trace decorator - for k, v in okwargs.items(): - trace(k, to_dict(v)) - - inputs = _inputs(func, args, kwargs) # type: ignore - trace("inputs", inputs) - try: - result = await func(*args, **kwargs) # type: ignore - trace("result", _results(result)) - except Exception as e: - trace( - "result", - { - "exception": { - "type": type(e), - "traceback": (traceback.format_tb(tb=e.__traceback__) if e.__traceback__ else None), - "message": str(e), - "args": to_dict(e.args), - } - }, - ) - raise e - - return result - - return wrapper - - -def trace(func: Union[Callable, None] = None, **kwargs: Any) -> Callable: - if func is None: - return partial(trace, **kwargs) - wrapped_method = _trace_async if inspect.iscoroutinefunction(func) else _trace_sync - return wrapped_method(func, **kwargs) - - -class PromptyTracer: - def __init__(self, output_dir: Union[str, None] = None) -> None: - if output_dir: - self.output = Path(output_dir).resolve().absolute() - else: - self.output = Path(Path(os.getcwd()) / ".runs").resolve().absolute() - - if not self.output.exists(): - self.output.mkdir(parents=True, exist_ok=True) - - self.stack: List[Dict[str, Any]] = [] - - @contextlib.contextmanager - def tracer(self, name: str) -> Iterator[Callable[[str, Any], None]]: - try: - self.stack.append({"name": name}) - frame = self.stack[-1] - frame["__time"] = { - "start": 
datetime.now(), - } - - def add(key: str, value: Any) -> None: - if key not in frame: - frame[key] = value - # multiple values creates list - else: - if isinstance(frame[key], list): - frame[key].append(value) - else: - frame[key] = [frame[key], value] - - yield add - finally: - frame = self.stack.pop() - start: datetime = frame["__time"]["start"] - end: datetime = datetime.now() - - # add duration to frame - frame["__time"] = { - "start": start.strftime("%Y-%m-%dT%H:%M:%S.%f"), - "end": end.strftime("%Y-%m-%dT%H:%M:%S.%f"), - "duration": int((end - start).total_seconds() * 1000), - } - - # hoist usage to parent frame - if "result" in frame and isinstance(frame["result"], dict): - if "usage" in frame["result"]: - frame["__usage"] = self.hoist_item( - frame["result"]["usage"], - frame["__usage"] if "__usage" in frame else {}, - ) - - # streamed results may have usage as well - if "result" in frame and isinstance(frame["result"], list): - for result in frame["result"]: - if isinstance(result, dict) and "usage" in result and isinstance(result["usage"], dict): - frame["__usage"] = self.hoist_item( - result["usage"], - frame["__usage"] if "__usage" in frame else {}, - ) - - # add any usage frames from below - if "__frames" in frame: - for child in frame["__frames"]: - if "__usage" in child: - frame["__usage"] = self.hoist_item( - child["__usage"], - frame["__usage"] if "__usage" in frame else {}, - ) - - # if stack is empty, dump the frame - if len(self.stack) == 0: - self.write_trace(frame) - # otherwise, append the frame to the parent - else: - if "__frames" not in self.stack[-1]: - self.stack[-1]["__frames"] = [] - self.stack[-1]["__frames"].append(frame) - - def hoist_item(self, src: Dict[str, Any], cur: Dict[str, Any]) -> Dict[str, Any]: - for key, value in src.items(): - if value is None or isinstance(value, list) or isinstance(value, dict): - continue - try: - if key not in cur: - cur[key] = value - else: - cur[key] += value - except: - continue - - return cur - - def write_trace(self, frame: Dict[str, Any]) -> None: - trace_file = self.output / f"{frame['name']}.{datetime.now().strftime('%Y%m%d.%H%M%S')}.tracy" - - v = importlib.metadata.version("prompty") # type: ignore - enriched_frame = { - "runtime": "python", - "version": v, - "trace": frame, - } - - with open(trace_file, "w") as f: - json.dump(enriched_frame, f, indent=4) - - -@contextlib.contextmanager -def console_tracer(name: str) -> Iterator[Callable[[str, Any], None]]: - try: - print(f"Starting {name}") - yield lambda key, value: print(f"{key}:\n{json.dumps(to_dict(value), indent=4)}") - finally: - print(f"Ending {name}") diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_utils.py b/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_utils.py deleted file mode 100644 index 22f284180ee1..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/prompts/_utils.py +++ /dev/null @@ -1,100 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
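The `Tracer` registry and `@trace` decorator above fan each traced call out to whatever sinks are registered. A small sketch using the `console_tracer` sink defined above; `add_numbers` is a made-up example, and the import path is the internal module being removed here.

```python
# Sketch: register a trace sink, then decorate a function with @trace.
from azure.ai.inference.prompts._tracer import Tracer, console_tracer, trace

Tracer.add("console", console_tracer)  # print every traced key/value to stdout
# A file-based sink also exists: Tracer.add("prompty", PromptyTracer().tracer)

@trace
def add_numbers(a: int, b: int) -> int:
    return a + b

add_numbers(2, 3)  # records "signature", "inputs" and "result" entries in the sink
```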
-# ------------------------------------ -# mypy: disable-error-code="import-untyped,return-value" -# pylint: disable=line-too-long,R,wrong-import-order,global-variable-not-assigned) -import json -import os -import re -import sys -from typing import Any, Dict -from pathlib import Path - - -_yaml_regex = re.compile( - r"^\s*" + r"(?:---|\+\+\+)" + r"(.*?)" + r"(?:---|\+\+\+)" + r"\s*(.+)$", - re.S | re.M, -) - - -def load_text(file_path, encoding="utf-8"): - with open(file_path, "r", encoding=encoding) as file: - return file.read() - - -def load_json(file_path, encoding="utf-8"): - return json.loads(load_text(file_path, encoding=encoding)) - - -def load_global_config(prompty_path: Path = Path.cwd(), configuration: str = "default") -> Dict[str, Any]: - prompty_config_path = prompty_path.joinpath("prompty.json") - if os.path.exists(prompty_config_path): - c = load_json(prompty_config_path) - if configuration in c: - return c[configuration] - else: - raise ValueError(f'Item "{configuration}" not found in "{prompty_config_path}"') - else: - return {} - - -def load_prompty(file_path, encoding="utf-8") -> Dict[str, Any]: - contents = load_text(file_path, encoding=encoding) - return parse(contents) - - -def parse(contents): - try: - import yaml # type: ignore - except ImportError as exc: - raise ImportError("Please install pyyaml to use this function. Run `pip install pyyaml`.") from exc - - global _yaml_regex - - fmatter = "" - body = "" - result = _yaml_regex.search(contents) - - if result: - fmatter = result.group(1) - body = result.group(2) - return { - "attributes": yaml.load(fmatter, Loader=yaml.SafeLoader), - "body": body, - "frontmatter": fmatter, - } - - -def remove_leading_empty_space(multiline_str: str) -> str: - """ - Processes a multiline string by: - 1. Removing empty lines - 2. Finding the minimum leading spaces - 3. Indenting all lines to the minimum level - - :param multiline_str: The input multiline string. - :type multiline_str: str - :return: The processed multiline string. - :rtype: str - """ - lines = multiline_str.splitlines() - start_index = 0 - while start_index < len(lines) and lines[start_index].strip() == "": - start_index += 1 - - # Find the minimum number of leading spaces - min_spaces = sys.maxsize - for line in lines[start_index:]: - if len(line.strip()) == 0: - continue - spaces = len(line) - len(line.lstrip()) - spaces += line.lstrip().count("\t") * 2 # Count tabs as 2 spaces - min_spaces = min(min_spaces, spaces) - - # Remove leading spaces and indent to the minimum level - processed_lines = [] - for line in lines[start_index:]: - processed_lines.append(line[min_spaces:]) - - return "\n".join(processed_lines) diff --git a/sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py b/sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py deleted file mode 100644 index dc3f0ed982e4..000000000000 --- a/sdk/ai/azure-ai-inference/azure/ai/inference/tracing.py +++ /dev/null @@ -1,823 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -import copy -from enum import Enum -import functools -import json -import importlib -import logging -import os -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union -from urllib.parse import urlparse - -# pylint: disable = no-name-in-module -from azure.core import CaseInsensitiveEnumMeta # type: ignore -from azure.core.settings import settings -from . 
import models as _models - -try: - # pylint: disable = no-name-in-module - from azure.core.tracing import AbstractSpan, SpanKind # type: ignore - from opentelemetry.trace import StatusCode, Span - - _tracing_library_available = True -except ModuleNotFoundError: - - _tracing_library_available = False - - -__all__ = [ - "AIInferenceInstrumentor", -] - - -_inference_traces_enabled: bool = False -_trace_inference_content: bool = False -_INFERENCE_GEN_AI_SYSTEM_NAME = "az.ai.inference" - - -class TraceType(str, Enum, metaclass=CaseInsensitiveEnumMeta): # pylint: disable=C4747 - """An enumeration class to represent different types of traces.""" - - INFERENCE = "Inference" - - -class AIInferenceInstrumentor: - """ - A class for managing the trace instrumentation of AI Inference. - - This class allows enabling or disabling tracing for AI Inference. - and provides functionality to check whether instrumentation is active. - - """ - - def __init__(self): - if not _tracing_library_available: - raise ModuleNotFoundError( - "Azure Core Tracing Opentelemetry is not installed. " - "Please install it using 'pip install azure-core-tracing-opentelemetry'" - ) - # In the future we could support different versions from the same library - # and have a parameter that specifies the version to use. - self._impl = _AIInferenceInstrumentorPreview() - - def instrument(self, enable_content_recording: Optional[bool] = None) -> None: - """ - Enable trace instrumentation for AI Inference. - - :param enable_content_recording: Whether content recording is enabled as part - of the traces or not. Content in this context refers to chat message content - and function call tool related function names, function parameter names and - values. True will enable content recording, False will disable it. If no value - s provided, then the value read from environment variable - AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED is used. If the environment variable - is not found, then the value will default to False. Please note that successive calls - to instrument will always apply the content recording value provided with the most - recent call to instrument (including applying the environment variable if no value is - provided and defaulting to false if the environment variable is not found), even if - instrument was already previously called without uninstrument being called in between - the instrument calls. - - :type enable_content_recording: bool, optional - """ - self._impl.instrument(enable_content_recording=enable_content_recording) - - def uninstrument(self) -> None: - """ - Disable trace instrumentation for AI Inference. - - Raises: - RuntimeError: If instrumentation is not currently enabled. - - This method removes any active instrumentation, stopping the tracing - of AI Inference. - """ - self._impl.uninstrument() - - def is_instrumented(self) -> bool: - """ - Check if trace instrumentation for AI Inference is currently enabled. - - :return: True if instrumentation is active, False otherwise. - :rtype: bool - """ - return self._impl.is_instrumented() - - def is_content_recording_enabled(self) -> bool: - """ - This function gets the content recording value. - - :return: A bool value indicating whether content recording is enabled. - :rtype: bool - """ - return self._impl.is_content_recording_enabled() - - -class _AIInferenceInstrumentorPreview: - """ - A class for managing the trace instrumentation of AI Inference. - - This class allows enabling or disabling tracing for AI Inference. 
- and provides functionality to check whether instrumentation is active. - """ - - def _str_to_bool(self, s): - if s is None: - return False - return str(s).lower() == "true" - - def instrument(self, enable_content_recording: Optional[bool] = None): - """ - Enable trace instrumentation for AI Inference. - - :param enable_content_recording: Whether content recording is enabled as part - of the traces or not. Content in this context refers to chat message content - and function call tool related function names, function parameter names and - values. True will enable content recording, False will disable it. If no value - is provided, then the value read from environment variable - AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED is used. If the environment variable - is not found, then the value will default to False. - - :type enable_content_recording: bool, optional - """ - if enable_content_recording is None: - var_value = os.environ.get("AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED") - enable_content_recording = self._str_to_bool(var_value) - if not self.is_instrumented(): - self._instrument_inference(enable_content_recording) - else: - self._set_content_recording_enabled(enable_content_recording=enable_content_recording) - - def uninstrument(self): - """ - Disable trace instrumentation for AI Inference. - - This method removes any active instrumentation, stopping the tracing - of AI Inference. - """ - if self.is_instrumented(): - self._uninstrument_inference() - - def is_instrumented(self): - """ - Check if trace instrumentation for AI Inference is currently enabled. - - :return: True if instrumentation is active, False otherwise. - :rtype: bool - """ - return self._is_instrumented() - - def set_content_recording_enabled(self, enable_content_recording: bool = False) -> None: - """This function sets the content recording value. - - :param enable_content_recording: Indicates whether tracing of message content should be enabled. - This also controls whether function call tool function names, - parameter names and parameter values are traced. - :type enable_content_recording: bool - """ - self._set_content_recording_enabled(enable_content_recording=enable_content_recording) - - def is_content_recording_enabled(self) -> bool: - """This function gets the content recording value. - - :return: A bool value indicating whether content tracing is enabled. 
- :rtype bool - """ - return self._is_content_recording_enabled() - - def _set_attributes(self, span: "AbstractSpan", *attrs: Tuple[str, Any]) -> None: - for attr in attrs: - key, value = attr - if value is not None: - span.add_attribute(key, value) - - def _add_request_chat_message_event(self, span: "AbstractSpan", **kwargs: Any) -> None: - for message in kwargs.get("messages", []): - try: - message = message.as_dict() - except AttributeError: - pass - - if message.get("role"): - name = f"gen_ai.{message.get('role')}.message" - span.span_instance.add_event( - name=name, - attributes={ - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(message), - }, - ) - - def _parse_url(self, url): - parsed = urlparse(url) - server_address = parsed.hostname - port = parsed.port - return server_address, port - - def _add_request_chat_attributes(self, span: "AbstractSpan", *args: Any, **kwargs: Any) -> None: - client = args[0] - endpoint = client._config.endpoint # pylint: disable=protected-access - server_address, port = self._parse_url(endpoint) - model = "chat" - if kwargs.get("model") is not None: - model_value = kwargs.get("model") - if model_value is not None: - model = model_value - - self._set_attributes( - span, - ("gen_ai.operation.name", "chat"), - ("gen_ai.system", _INFERENCE_GEN_AI_SYSTEM_NAME), - ("gen_ai.request.model", model), - ("gen_ai.request.max_tokens", kwargs.get("max_tokens")), - ("gen_ai.request.temperature", kwargs.get("temperature")), - ("gen_ai.request.top_p", kwargs.get("top_p")), - ("server.address", server_address), - ) - if port is not None and port != 443: - span.add_attribute("server.port", port) - - def _remove_function_call_names_and_arguments(self, tool_calls: list) -> list: - tool_calls_copy = copy.deepcopy(tool_calls) - for tool_call in tool_calls_copy: - if "function" in tool_call: - if "name" in tool_call["function"]: - del tool_call["function"]["name"] - if "arguments" in tool_call["function"]: - del tool_call["function"]["arguments"] - if not tool_call["function"]: - del tool_call["function"] - return tool_calls_copy - - def _get_finish_reasons(self, result) -> Optional[List[str]]: - if hasattr(result, "choices") and result.choices: - finish_reasons: List[str] = [] - for choice in result.choices: - finish_reason = getattr(choice, "finish_reason", None) - - if finish_reason is None: - # If finish_reason is None, default to "none" - finish_reasons.append("none") - elif hasattr(finish_reason, "value"): - # If finish_reason has a 'value' attribute (i.e., it's an enum), use it - finish_reasons.append(finish_reason.value) - elif isinstance(finish_reason, str): - # If finish_reason is a string, use it directly - finish_reasons.append(finish_reason) - else: - # Default to "none" - finish_reasons.append("none") - - return finish_reasons - return None - - def _get_finish_reason_for_choice(self, choice): - finish_reason = getattr(choice, "finish_reason", None) - if finish_reason is not None: - return finish_reason.value - - return "none" - - def _add_response_chat_message_event(self, span: "AbstractSpan", result: _models.ChatCompletions) -> None: - for choice in result.choices: - if _trace_inference_content: - full_response: Dict[str, Any] = { - "message": {"content": choice.message.content}, - "finish_reason": self._get_finish_reason_for_choice(choice), - "index": choice.index, - } - if choice.message.tool_calls: - full_response["message"]["tool_calls"] = [tool.as_dict() for tool in choice.message.tool_calls] - attributes = { - 
"gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(full_response), - } - else: - response: Dict[str, Any] = { - "finish_reason": self._get_finish_reason_for_choice(choice), - "index": choice.index, - } - if choice.message.tool_calls: - response["message"] = {} - tool_calls_function_names_and_arguments_removed = self._remove_function_call_names_and_arguments( - choice.message.tool_calls - ) - response["message"]["tool_calls"] = [ - tool.as_dict() for tool in tool_calls_function_names_and_arguments_removed - ] - - attributes = { - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(response), - } - span.span_instance.add_event(name="gen_ai.choice", attributes=attributes) - - def _add_response_chat_attributes( - self, - span: "AbstractSpan", - result: Union[_models.ChatCompletions, _models.StreamingChatCompletionsUpdate], - ) -> None: - self._set_attributes( - span, - ("gen_ai.response.id", result.id), - ("gen_ai.response.model", result.model), - ( - "gen_ai.usage.input_tokens", - (result.usage.prompt_tokens if hasattr(result, "usage") and result.usage else None), - ), - ( - "gen_ai.usage.output_tokens", - (result.usage.completion_tokens if hasattr(result, "usage") and result.usage else None), - ), - ) - finish_reasons = self._get_finish_reasons(result) - if not finish_reasons is None: - span.add_attribute("gen_ai.response.finish_reasons", finish_reasons) # type: ignore - - def _add_request_span_attributes(self, span: "AbstractSpan", _span_name: str, args: Any, kwargs: Any) -> None: - self._add_request_chat_attributes(span, *args, **kwargs) - if _trace_inference_content: - self._add_request_chat_message_event(span, **kwargs) - - def _add_response_span_attributes(self, span: "AbstractSpan", result: object) -> None: - if isinstance(result, _models.ChatCompletions): - self._add_response_chat_attributes(span, result) - self._add_response_chat_message_event(span, result) - # TODO add more models here - - def _accumulate_response(self, item, accumulate: Dict[str, Any]) -> None: - if item.finish_reason: - accumulate["finish_reason"] = item.finish_reason - if item.index: - accumulate["index"] = item.index - if item.delta.content: - accumulate.setdefault("message", {}) - accumulate["message"].setdefault("content", "") - accumulate["message"]["content"] += item.delta.content - if item.delta.tool_calls: - accumulate.setdefault("message", {}) - accumulate["message"].setdefault("tool_calls", []) - if item.delta.tool_calls is not None: - for tool_call in item.delta.tool_calls: - if tool_call.id: - accumulate["message"]["tool_calls"].append( - { - "id": tool_call.id, - "type": "", - "function": {"name": "", "arguments": ""}, - } - ) - if tool_call.function: - accumulate["message"]["tool_calls"][-1]["type"] = "function" - if tool_call.function and tool_call.function.name: - accumulate["message"]["tool_calls"][-1]["function"]["name"] = tool_call.function.name - if tool_call.function and tool_call.function.arguments: - accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments - - def _accumulate_async_streaming_response(self, item, accumulate: Dict[str, Any]) -> None: - if not "choices" in item: - return - if "finish_reason" in item["choices"][0] and item["choices"][0]["finish_reason"]: - accumulate["finish_reason"] = item["choices"][0]["finish_reason"] - if "index" in item["choices"][0] and item["choices"][0]["index"]: - accumulate["index"] = item["choices"][0]["index"] - if not "delta" in 
item["choices"][0]: - return - if "content" in item["choices"][0]["delta"] and item["choices"][0]["delta"]["content"]: - accumulate.setdefault("message", {}) - accumulate["message"].setdefault("content", "") - accumulate["message"]["content"] += item["choices"][0]["delta"]["content"] - if "tool_calls" in item["choices"][0]["delta"] and item["choices"][0]["delta"]["tool_calls"]: - accumulate.setdefault("message", {}) - accumulate["message"].setdefault("tool_calls", []) - if item["choices"][0]["delta"]["tool_calls"] is not None: - for tool_call in item["choices"][0]["delta"]["tool_calls"]: - if tool_call.id: - accumulate["message"]["tool_calls"].append( - { - "id": tool_call.id, - "type": "", - "function": {"name": "", "arguments": ""}, - } - ) - if tool_call.function: - accumulate["message"]["tool_calls"][-1]["type"] = "function" - if tool_call.function and tool_call.function.name: - accumulate["message"]["tool_calls"][-1]["function"]["name"] = tool_call.function.name - if tool_call.function and tool_call.function.arguments: - accumulate["message"]["tool_calls"][-1]["function"]["arguments"] += tool_call.function.arguments - - def _wrapped_stream( - self, stream_obj: _models.StreamingChatCompletions, span: "AbstractSpan" - ) -> _models.StreamingChatCompletions: - class StreamWrapper(_models.StreamingChatCompletions): - def __init__(self, stream_obj, instrumentor): - super().__init__(stream_obj._response) - self._instrumentor = instrumentor - - def __iter__( # pyright: ignore [reportIncompatibleMethodOverride] - self, - ) -> Iterator[_models.StreamingChatCompletionsUpdate]: - accumulate: Dict[str, Any] = {} - try: - chunk = None - for chunk in stream_obj: - for item in chunk.choices: - self._instrumentor._accumulate_response(item, accumulate) - yield chunk - - if chunk is not None: - self._instrumentor._add_response_chat_attributes(span, chunk) - - except Exception as exc: - # Set the span status to error - if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] - span.span_instance.set_status( - StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] - description=str(exc), - ) - module = exc.__module__ if hasattr(exc, "__module__") and exc.__module__ != "builtins" else "" - error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ - self._instrumentor._set_attributes(span, ("error.type", error_type)) - raise - - finally: - if stream_obj._done is False: - if accumulate.get("finish_reason") is None: - accumulate["finish_reason"] = "error" - else: - # Only one choice expected with streaming - accumulate["index"] = 0 - # Delete message if content tracing is not enabled - if not _trace_inference_content: - if "message" in accumulate: - if "content" in accumulate["message"]: - del accumulate["message"]["content"] - if not accumulate["message"]: - del accumulate["message"] - if "message" in accumulate: - if "tool_calls" in accumulate["message"]: - tool_calls_function_names_and_arguments_removed = ( - self._instrumentor._remove_function_call_names_and_arguments( - accumulate["message"]["tool_calls"] - ) - ) - accumulate["message"]["tool_calls"] = list( - tool_calls_function_names_and_arguments_removed - ) - - span.span_instance.add_event( - name="gen_ai.choice", - attributes={ - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(accumulate), - }, - ) - span.finish() - - return StreamWrapper(stream_obj, self) - - def _async_wrapped_stream( - self, stream_obj: _models.AsyncStreamingChatCompletions, 
span: "AbstractSpan" - ) -> _models.AsyncStreamingChatCompletions: - class AsyncStreamWrapper(_models.AsyncStreamingChatCompletions): - def __init__(self, stream_obj, instrumentor, span): - super().__init__(stream_obj._response) - self._instrumentor = instrumentor - self._accumulate: Dict[str, Any] = {} - self._stream_obj = stream_obj - self.span = span - self._last_result = None - - async def __anext__(self) -> "_models.StreamingChatCompletionsUpdate": - try: - result = await super().__anext__() - self._instrumentor._accumulate_async_streaming_response( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] - result, self._accumulate - ) - self._last_result = result - except StopAsyncIteration as exc: - self._trace_stream_content() - raise exc - return result - - def _trace_stream_content(self) -> None: - if self._last_result: - self._instrumentor._add_response_chat_attributes( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] - span, self._last_result - ) - # Only one choice expected with streaming - self._accumulate["index"] = 0 - # Delete message if content tracing is not enabled - if not _trace_inference_content: - if "message" in self._accumulate: - if "content" in self._accumulate["message"]: - del self._accumulate["message"]["content"] - if not self._accumulate["message"]: - del self._accumulate["message"] - if "message" in self._accumulate: - if "tool_calls" in self._accumulate["message"]: - tools_no_recording = self._instrumentor._remove_function_call_names_and_arguments( # pylint: disable=protected-access, line-too-long # pyright: ignore [reportFunctionMemberAccess] - self._accumulate["message"]["tool_calls"] - ) - self._accumulate["message"]["tool_calls"] = list(tools_no_recording) - - self.span.span_instance.add_event( - name="gen_ai.choice", - attributes={ - "gen_ai.system": _INFERENCE_GEN_AI_SYSTEM_NAME, - "gen_ai.event.content": json.dumps(self._accumulate), - }, - ) - span.finish() - - async_stream_wrapper = AsyncStreamWrapper(stream_obj, self, span) - return async_stream_wrapper - - def _trace_sync_function( - self, - function: Callable, - *, - _args_to_ignore: Optional[List[str]] = None, - _trace_type=TraceType.INFERENCE, - _name: Optional[str] = None, - ) -> Callable: - """ - Decorator that adds tracing to a synchronous function. - - :param function: The function to be traced. - :type function: Callable - :param args_to_ignore: A list of argument names to be ignored in the trace. - Defaults to None. - :type: args_to_ignore: [List[str]], optional - :param trace_type: The type of the trace. Defaults to TraceType.INFERENCE. - :type trace_type: TraceType, optional - :param name: The name of the trace, will set to func name if not provided. - :type name: str, optional - :return: The traced function. 
- :rtype: Callable - """ - - @functools.wraps(function) - def inner(*args, **kwargs): - - span_impl_type = settings.tracing_implementation() - if span_impl_type is None: - return function(*args, **kwargs) - - class_function_name = function.__qualname__ - - if class_function_name.startswith("ChatCompletionsClient.complete"): - if kwargs.get("model") is None: - span_name = "chat" - else: - model = kwargs.get("model") - span_name = f"chat {model}" - - span = span_impl_type( - name=span_name, - kind=SpanKind.CLIENT, # pyright: ignore [reportPossiblyUnboundVariable] - ) - try: - # tracing events not supported in azure-core-tracing-opentelemetry - # so need to access the span instance directly - with span_impl_type.change_context(span.span_instance): - self._add_request_span_attributes(span, span_name, args, kwargs) - result = function(*args, **kwargs) - if kwargs.get("stream") is True: - return self._wrapped_stream(result, span) - self._add_response_span_attributes(span, result) - - except Exception as exc: - # Set the span status to error - if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] - span.span_instance.set_status( - StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] - description=str(exc), - ) - module = getattr(exc, "__module__", "") - module = module if module != "builtins" else "" - error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ - self._set_attributes(span, ("error.type", error_type)) - span.finish() - raise - - span.finish() - return result - - # Handle the default case (if the function name does not match) - return None # Ensure all paths return - - return inner - - def _trace_async_function( - self, - function: Callable, - *, - _args_to_ignore: Optional[List[str]] = None, - _trace_type=TraceType.INFERENCE, - _name: Optional[str] = None, - ) -> Callable: - """ - Decorator that adds tracing to an asynchronous function. - - :param function: The function to be traced. - :type function: Callable - :param args_to_ignore: A list of argument names to be ignored in the trace. - Defaults to None. - :type: args_to_ignore: [List[str]], optional - :param trace_type: The type of the trace. Defaults to TraceType.INFERENCE. - :type trace_type: TraceType, optional - :param name: The name of the trace, will set to func name if not provided. - :type name: str, optional - :return: The traced function. 
- :rtype: Callable - """ - - @functools.wraps(function) - async def inner(*args, **kwargs): - span_impl_type = settings.tracing_implementation() - if span_impl_type is None: - return await function(*args, **kwargs) - - class_function_name = function.__qualname__ - - if class_function_name.startswith("ChatCompletionsClient.complete"): - if kwargs.get("model") is None: - span_name = "chat" - else: - model = kwargs.get("model") - span_name = f"chat {model}" - - span = span_impl_type( - name=span_name, - kind=SpanKind.CLIENT, # pyright: ignore [reportPossiblyUnboundVariable] - ) - try: - # tracing events not supported in azure-core-tracing-opentelemetry - # so need to access the span instance directly - with span_impl_type.change_context(span.span_instance): - self._add_request_span_attributes(span, span_name, args, kwargs) - result = await function(*args, **kwargs) - if kwargs.get("stream") is True: - return self._async_wrapped_stream(result, span) - self._add_response_span_attributes(span, result) - - except Exception as exc: - # Set the span status to error - if isinstance(span.span_instance, Span): # pyright: ignore [reportPossiblyUnboundVariable] - span.span_instance.set_status( - StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable] - description=str(exc), - ) - module = getattr(exc, "__module__", "") - module = module if module != "builtins" else "" - error_type = f"{module}.{type(exc).__name__}" if module else type(exc).__name__ - self._set_attributes(span, ("error.type", error_type)) - span.finish() - raise - - span.finish() - return result - - # Handle the default case (if the function name does not match) - return None # Ensure all paths return - - return inner - - def _inject_async(self, f, _trace_type, _name): - wrapper_fun = self._trace_async_function(f) - wrapper_fun._original = f # pylint: disable=protected-access # pyright: ignore [reportFunctionMemberAccess] - return wrapper_fun - - def _inject_sync(self, f, _trace_type, _name): - wrapper_fun = self._trace_sync_function(f) - wrapper_fun._original = f # pylint: disable=protected-access # pyright: ignore [reportFunctionMemberAccess] - return wrapper_fun - - def _inference_apis(self): - sync_apis = ( - ( - "azure.ai.inference", - "ChatCompletionsClient", - "complete", - TraceType.INFERENCE, - "inference_chat_completions_complete", - ), - ) - async_apis = ( - ( - "azure.ai.inference.aio", - "ChatCompletionsClient", - "complete", - TraceType.INFERENCE, - "inference_chat_completions_complete", - ), - ) - return sync_apis, async_apis - - def _inference_api_list(self): - sync_apis, async_apis = self._inference_apis() - yield sync_apis, self._inject_sync - yield async_apis, self._inject_async - - def _generate_api_and_injector(self, apis): - for api, injector in apis: - for module_name, class_name, method_name, trace_type, name in api: - try: - module = importlib.import_module(module_name) - api = getattr(module, class_name) - if hasattr(api, method_name): - yield api, method_name, trace_type, injector, name - except AttributeError as e: - # Log the attribute exception with the missing class information - logging.warning( - "AttributeError: The module '%s' does not have the class '%s'. 
%s", - module_name, - class_name, - str(e), - ) - except Exception as e: # pylint: disable=broad-except - # Log other exceptions as a warning, as we're not sure what they might be - logging.warning("An unexpected error occurred: '%s'", str(e)) - - def _available_inference_apis_and_injectors(self): - """ - Generates a sequence of tuples containing Inference API classes, method names, and - corresponding injector functions. - - :return: A generator yielding tuples. - :rtype: tuple - """ - yield from self._generate_api_and_injector(self._inference_api_list()) - - def _instrument_inference(self, enable_content_tracing: bool = False): - """This function modifies the methods of the Inference API classes to - inject logic before calling the original methods. - The original methods are stored as _original attributes of the methods. - - :param enable_content_tracing: Indicates whether tracing of message content should be enabled. - This also controls whether function call tool function names, - parameter names and parameter values are traced. - :type enable_content_tracing: bool - """ - # pylint: disable=W0603 - global _inference_traces_enabled - global _trace_inference_content - if _inference_traces_enabled: - raise RuntimeError("Traces already started for azure.ai.inference") - _inference_traces_enabled = True - _trace_inference_content = enable_content_tracing - for ( - api, - method, - trace_type, - injector, - name, - ) in self._available_inference_apis_and_injectors(): - # Check if the method of the api class has already been modified - if not hasattr(getattr(api, method), "_original"): - setattr(api, method, injector(getattr(api, method), trace_type, name)) - - def _uninstrument_inference(self): - """This function restores the original methods of the Inference API classes - by assigning them back from the _original attributes of the modified methods. - """ - # pylint: disable=W0603 - global _inference_traces_enabled - global _trace_inference_content - _trace_inference_content = False - for api, method, _, _, _ in self._available_inference_apis_and_injectors(): - if hasattr(getattr(api, method), "_original"): - setattr(api, method, getattr(getattr(api, method), "_original")) - _inference_traces_enabled = False - - def _is_instrumented(self): - """This function returns True if Inference libary has already been instrumented - for tracing and False if it has not been instrumented. - - :return: A value indicating whether the Inference library is currently instrumented or not. - :rtype: bool - """ - return _inference_traces_enabled - - def _set_content_recording_enabled(self, enable_content_recording: bool = False) -> None: - """This function sets the content recording value. - - :param enable_content_recording: Indicates whether tracing of message content should be enabled. - This also controls whether function call tool function names, - parameter names and parameter values are traced. - :type enable_content_recording: bool - """ - global _trace_inference_content # pylint: disable=W0603 - _trace_inference_content = enable_content_recording - - def _is_content_recording_enabled(self) -> bool: - """This function gets the content recording value. - - :return: A bool value indicating whether content tracing is enabled. 
- :rtype bool - """ - return _trace_inference_content diff --git a/sdk/ai/azure-ai-inference/sdk_packaging.toml b/sdk/ai/azure-ai-inference/sdk_packaging.toml new file mode 100644 index 000000000000..e7687fdae93b --- /dev/null +++ b/sdk/ai/azure-ai-inference/sdk_packaging.toml @@ -0,0 +1,2 @@ +[packaging] +auto_update = false \ No newline at end of file diff --git a/sdk/ai/azure-ai-inference/setup.py b/sdk/ai/azure-ai-inference/setup.py index 999dd87812fa..c7b5395a3f9f 100644 --- a/sdk/ai/azure-ai-inference/setup.py +++ b/sdk/ai/azure-ai-inference/setup.py @@ -13,7 +13,7 @@ PACKAGE_NAME = "azure-ai-inference" -PACKAGE_PPRINT_NAME = "Azure AI Inference" +PACKAGE_PPRINT_NAME = "Azure Ai Inference" # a-b-c => a/b/c package_folder_path = PACKAGE_NAME.replace("-", "/") @@ -35,7 +35,7 @@ license="MIT License", author="Microsoft Corporation", author_email="azpysdkhelp@microsoft.com", - url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/ai/azure-ai-inference", + url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk", keywords="azure, azure sdk", classifiers=[ "Development Status :: 4 - Beta", @@ -62,10 +62,10 @@ package_data={ "azure.ai.inference": ["py.typed"], }, - install_requires=["isodate>=0.6.1", "azure-core>=1.30.0", "typing-extensions>=4.6.0"], + install_requires=[ + "isodate>=0.6.1", + "azure-core>=1.30.0", + "typing-extensions>=4.6.0", + ], python_requires=">=3.8", - extras_require={ - "opentelemetry": ["azure-core-tracing-opentelemetry"], - "prompts": ["pyyaml"], - }, ) diff --git a/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py b/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py index 5ea57b1e2935..036e1052b59d 100644 --- a/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py +++ b/sdk/ai/azure-ai-inference/tests/test_model_inference_client.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-lines # ------------------------------------ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. diff --git a/sdk/ai/azure-ai-inference/tsp-location.yaml b/sdk/ai/azure-ai-inference/tsp-location.yaml index df185250688b..938dc177810e 100644 --- a/sdk/ai/azure-ai-inference/tsp-location.yaml +++ b/sdk/ai/azure-ai-inference/tsp-location.yaml @@ -1,4 +1,4 @@ directory: specification/ai/ModelClient -commit: 3e95e575e537024a02470cf59c7a78078dc10cd1 +commit: 834383067ac02f95702b17f494fc1df973bd9455 repo: Azure/azure-rest-api-specs -additionalDirectories: +additionalDirectories:
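For reference, the `AIInferenceInstrumentor` deleted in this diff (along with the `opentelemetry` extra in `setup.py`) was enabled as sketched below. This assumes `azure-core-tracing-opentelemetry` is installed and an OpenTelemetry tracer provider is configured separately; the environment-variable fallback matches the docstrings above.

```python
# Sketch of the removed azure.ai.inference.tracing API.
from azure.ai.inference.tracing import AIInferenceInstrumentor

instrumentor = AIInferenceInstrumentor()
# Without an argument, AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED decides content recording.
instrumentor.instrument(enable_content_recording=True)
assert instrumentor.is_instrumented()

# ChatCompletionsClient.complete(...) calls made here are wrapped in client "chat" spans.

instrumentor.uninstrument()
```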