Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload
from typing_extensions import Self, override
import functools

from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload, Iterable
from typing_extensions import Self, override, ParamSpec

import httpx

Expand All @@ -14,6 +16,13 @@
from .._streaming import Stream, AsyncStream
from .._exceptions import OpenAIError
from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
from .._compat import cached_property
from ..resources.chat import Chat, Completions
from .azure_types import AzureChatExtensionConfiguration, AzureChatEnhancementConfiguration

P = ParamSpec("P")
T = TypeVar("T")


_deployments_endpoints = set(
[
Expand All @@ -40,6 +49,44 @@
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])


def with_azure_options_wrapper(func: Callable[P, T], extras: Any) -> Callable[P, T]:

@functools.wraps(func)
def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
kwargs.update({"extra_body": extras})
return func(*args, **kwargs)

return wrapped


class CompletionsWithAzureOptions(Completions):
def __init__(self, completions: AzureCompletions, **extras: Any) -> None:
self._completions = completions

self.create = with_azure_options_wrapper(
completions.create, extras
)


class AzureCompletions(Completions):

def with_azure_options(
self,
*,
data_sources: Iterable[AzureChatExtensionConfiguration] | None = None,
enhancements: AzureChatEnhancementConfiguration | None = None
) -> CompletionsWithAzureOptions:
return CompletionsWithAzureOptions(self, data_sources=data_sources, enhancements=enhancements)


class AzureChat(Chat):

@override # type: ignore
@cached_property
def completions(self) -> AzureCompletions:
return AzureCompletions(self._client)


class MutuallyExclusiveAuthError(OpenAIError):
def __init__(self) -> None:
super().__init__(
Expand All @@ -62,6 +109,8 @@ def _build_request(


class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
chat: AzureChat

@overload
def __init__(
self,
Expand Down Expand Up @@ -216,6 +265,7 @@ def __init__(
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self.chat = AzureChat(self) # type: ignore

@override
def copy(
Expand Down
9 changes: 8 additions & 1 deletion src/openai/lib/azure_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"OnYourDataVectorizationSourceType",
"AzureCognitiveSearchQueryType",
"ElasticsearchQueryType",
"AzureChatExtensionConfiguration"
]

AzureChatExtensionType = Literal["AzureCognitiveSearch", "AzureMLIndex", "AzureCosmosDB", "Elasticsearch", "Pinecone"]
Expand All @@ -52,7 +53,13 @@
OnYourDataVectorizationSourceType = Literal["Endpoint", "DeploymentName", "ModelId"]
AzureCognitiveSearchQueryType = Literal["simple", "semantic", "vector", "vectorSimpleHybrid", "vectorSemanticHybrid"]
ElasticsearchQueryType = Literal["simple", "vector"]

AzureChatExtensionConfiguration = Union[
"AzureCognitiveSearchChatExtensionConfiguration",
"AzureCosmosDBChatExtensionConfiguration",
"AzureMachineLearningIndexChatExtensionConfiguration",
"ElasticsearchChatExtensionConfiguration",
"PineconeChatExtensionConfiguration",
]

class AzureChatEnhancementConfiguration(TypedDict, total=False):

Expand Down