diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 2c8b4dcd88..abd718e36c 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -2,8 +2,10 @@ import os import inspect -from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload -from typing_extensions import Self, override +import functools + +from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload, Iterable +from typing_extensions import Self, override, ParamSpec import httpx @@ -14,6 +16,13 @@ from .._streaming import Stream, AsyncStream from .._exceptions import OpenAIError from .._base_client import DEFAULT_MAX_RETRIES, BaseClient +from .._compat import cached_property +from ..resources.chat import Chat, Completions +from .azure_types import AzureChatExtensionConfiguration, AzureChatEnhancementConfiguration + +P = ParamSpec("P") +T = TypeVar("T") + _deployments_endpoints = set( [ @@ -40,6 +49,44 @@ API_KEY_SENTINEL = "".join(["<", "missing API key", ">"]) +def with_azure_options_wrapper(func: Callable[P, T], extras: Any) -> Callable[P, T]: + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + kwargs.update({"extra_body": extras}) + return func(*args, **kwargs) + + return wrapped + + +class CompletionsWithAzureOptions(Completions): + def __init__(self, completions: AzureCompletions, **extras: Any) -> None: + self._completions = completions + + self.create = with_azure_options_wrapper( + completions.create, extras + ) + + +class AzureCompletions(Completions): + + def with_azure_options( + self, + *, + data_sources: Iterable[AzureChatExtensionConfiguration] | None = None, + enhancements: AzureChatEnhancementConfiguration | None = None + ) -> CompletionsWithAzureOptions: + return CompletionsWithAzureOptions(self, data_sources=data_sources, enhancements=enhancements) + + +class AzureChat(Chat): + + @override # type: ignore + @cached_property + def completions(self) -> AzureCompletions: + return AzureCompletions(self._client) + + class MutuallyExclusiveAuthError(OpenAIError): def __init__(self) -> None: super().__init__( @@ -62,6 +109,8 @@ def _build_request( class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI): + chat: AzureChat + @overload def __init__( self, @@ -216,6 +265,7 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + self.chat = AzureChat(self) # type: ignore @override def copy( diff --git a/src/openai/lib/azure_types.py b/src/openai/lib/azure_types.py index 8c89b78fcb..51c58a2329 100644 --- a/src/openai/lib/azure_types.py +++ b/src/openai/lib/azure_types.py @@ -37,6 +37,7 @@ "OnYourDataVectorizationSourceType", "AzureCognitiveSearchQueryType", "ElasticsearchQueryType", + "AzureChatExtensionConfiguration" ] AzureChatExtensionType = Literal["AzureCognitiveSearch", "AzureMLIndex", "AzureCosmosDB", "Elasticsearch", "Pinecone"] @@ -52,7 +53,13 @@ OnYourDataVectorizationSourceType = Literal["Endpoint", "DeploymentName", "ModelId"] AzureCognitiveSearchQueryType = Literal["simple", "semantic", "vector", "vectorSimpleHybrid", "vectorSemanticHybrid"] ElasticsearchQueryType = Literal["simple", "vector"] - +AzureChatExtensionConfiguration = Union[ + "AzureCognitiveSearchChatExtensionConfiguration", + "AzureCosmosDBChatExtensionConfiguration", + "AzureMachineLearningIndexChatExtensionConfiguration", + "ElasticsearchChatExtensionConfiguration", + "PineconeChatExtensionConfiguration", +] class AzureChatEnhancementConfiguration(TypedDict, total=False):