diff --git a/src/openai/lib/_azure_realtime.py b/src/openai/lib/_azure_realtime.py
new file mode 100644
index 0000000000..0954664134
--- /dev/null
+++ b/src/openai/lib/_azure_realtime.py
@@ -0,0 +1,258 @@
+"""Azure-specific overrides of the beta Realtime (WebSocket) resources."""
+
+import logging
+from typing import TYPE_CHECKING
+from typing_extensions import override
+
+from .._types import Query, Headers
+from .._compat import cached_property
+from ..resources import beta
+from .._exceptions import OpenAIError
+from .._base_client import _merge_mappings
+from ..resources.beta import realtime
+from ..types.websocket_connection_options import WebsocketConnectionOptions
+
+if TYPE_CHECKING:
+    # Imported for annotations only: `azure.py` imports this module at runtime,
+    # so importing it back unconditionally would be circular.
+    from .azure import AzureOpenAI, AsyncAzureOpenAI
+
+log: logging.Logger = logging.getLogger(__name__)
+
+
+class RealtimeConnectionManager(realtime.realtime.RealtimeConnectionManager):
+    """Synchronous realtime connection manager that adds Azure query parameters and auth headers."""
+
+    # NOTE(review): every `self.__name` used here is name-mangled to
+    # `_RealtimeConnectionManager__name`. These presumably resolve to the base
+    # class's private state only because both classes share the same class name;
+    # confirm against the base implementation before renaming this class.
+
+    @override
+    def __enter__(self) -> realtime.realtime.RealtimeConnection:
+        """
+        👋 If your application doesn't work well with the context manager approach then you
+        can call this method directly to initiate a connection.
+
+        **Warning**: You must remember to close the connection with `.close()`.
+
+        ```py
+        connection = client.beta.realtime.connect(...).enter()
+        # ...
+        connection.close()
+        ```
+        """
+        try:
+            from websockets.sync.client import connect
+        except ImportError as exc:
+            raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
+
+        # Bare annotation for type checkers; it has no runtime effect.
+        self.__client: AzureOpenAI
+
+        auth_headers = self.__client.auth_headers
+        # `api-version` and `deployment` follow the spread, so they override same-named caller keys.
+        extra_query = {
+            **self.__extra_query,
+            "api-version": self.__client._api_version,
+            "deployment": self.__client._azure_deployment or self.__model,
+        }
+        # Prefer the API key; otherwise use an Azure AD bearer token when one is returned.
+        # If neither applies, the client's default `auth_headers` are sent unchanged.
+        if self.__client.api_key != "":
+            auth_headers = {"api-key": self.__client.api_key}
+        else:
+            token = self.__client._get_azure_ad_token()
+            if token:
+                auth_headers = {"Authorization": f"Bearer {token}"}
+
+        url = self._prepare_url().copy_with(
+            params={
+                **self.__client.base_url.params,
+                "model": self.__model,
+                **extra_query,
+            },
+        )
+        log.debug("Connecting to %s", url)
+        if self.__websocket_connection_options:
+            log.debug("Connection options: %s", self.__websocket_connection_options)
+
+        self.__connection = realtime.realtime.RealtimeConnection(
+            connect(
+                str(url),
+                user_agent_header=self.__client.user_agent,
+                additional_headers=_merge_mappings(
+                    {
+                        **auth_headers,
+                        "OpenAI-Beta": "realtime=v1",
+                    },
+                    self.__extra_headers,
+                ),
+                **self.__websocket_connection_options,
+            )
+        )
+
+        return self.__connection
+
+    # Public alias so a connection can be opened without a `with` block.
+    enter = __enter__
+
+
+class Realtime(realtime.Realtime):
+    @override
+    def connect(
+        self,
+        *,
+        model: str,
+        extra_query: Query = {},
+        extra_headers: Headers = {},
+        websocket_connection_options: WebsocketConnectionOptions = {},
+    ) -> RealtimeConnectionManager:
+        """
+        The Realtime API enables you to build low-latency, multi-modal conversational experiences. It currently supports text and audio as both input and output, as well as function calling.
+
+        Some notable benefits of the API include:
+
+        - Native speech-to-speech: Skipping an intermediate text format means low latency and nuanced output.
+        - Natural, steerable voices: The models have natural inflection and can laugh, whisper, and adhere to tone direction.
+        - Simultaneous multimodal output: Text is useful for moderation; faster-than-realtime audio ensures stable playback.
+
+        The Realtime API is a stateful, event-based API that communicates over a WebSocket.
+
+        This override returns the Azure `RealtimeConnectionManager` defined in this module; the
+        WebSocket itself is opened by the manager's `__enter__` / `enter()`.
+        """
+        # The `{}` defaults are shared objects; they are only passed through here, never mutated.
+        return RealtimeConnectionManager(
+            client=self._client,
+            extra_query=extra_query,
+            extra_headers=extra_headers,
+            websocket_connection_options=websocket_connection_options,
+            model=model,
+        )
+
+
+class Beta(beta.Beta):
+    @cached_property
+    def realtime(self) -> Realtime:  # type: ignore[reportImplicitOverride]
+        """Return the Azure-aware `Realtime` resource bound to this client."""
+        return Realtime(self._client)
+
+
+class AsyncRealtimeConnectionManager(realtime.realtime.AsyncRealtimeConnectionManager):
+    """Asynchronous realtime connection manager that adds Azure query parameters and auth headers."""
+
+    # NOTE(review): every `self.__name` used here is name-mangled to
+    # `_AsyncRealtimeConnectionManager__name`. These presumably resolve to the base
+    # class's private state only because both classes share the same class name;
+    # confirm against the base implementation before renaming this class.
+
+    @override
+    async def __aenter__(self) -> realtime.realtime.AsyncRealtimeConnection:
+        """
+        👋 If your application doesn't work well with the context manager approach then you
+        can call this method directly to initiate a connection.
+
+        **Warning**: You must remember to close the connection with `.close()`.
+
+        ```py
+        connection = await client.beta.realtime.connect(...).enter()
+        # ...
+        await connection.close()
+        ```
+        """
+        try:
+            from websockets.asyncio.client import connect
+        except ImportError as exc:
+            raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
+
+        # Bare annotation for type checkers; it has no runtime effect.
+        self.__client: AsyncAzureOpenAI
+
+        auth_headers = self.__client.auth_headers
+        # `api-version` and `deployment` follow the spread, so they override same-named caller keys.
+        extra_query = {
+            **self.__extra_query,
+            "api-version": self.__client._api_version,
+            "deployment": self.__client._azure_deployment or self.__model,
+        }
+        # Prefer the API key; otherwise use an Azure AD bearer token when one is returned.
+        # If neither applies, the client's default `auth_headers` are sent unchanged.
+        if self.__client.api_key != "":
+            auth_headers = {"api-key": self.__client.api_key}
+        else:
+            token = await self.__client._get_azure_ad_token()
+            if token:
+                auth_headers = {"Authorization": f"Bearer {token}"}
+
+        url = self._prepare_url().copy_with(
+            params={
+                **self.__client.base_url.params,
+                "model": self.__model,
+                **extra_query,
+            },
+        )
+        log.debug("Connecting to %s", url)
+        if self.__websocket_connection_options:
+            log.debug("Connection options: %s", self.__websocket_connection_options)
+
+        self.__connection = realtime.realtime.AsyncRealtimeConnection(
+            await connect(
+                str(url),
+                user_agent_header=self.__client.user_agent,
+                additional_headers=_merge_mappings(
+                    {
+                        **auth_headers,
+                        "OpenAI-Beta": "realtime=v1",
+                    },
+                    self.__extra_headers,
+                ),
+                **self.__websocket_connection_options,
+            )
+        )
+
+        return self.__connection
+
+    # Public alias so a connection can be opened without an `async with` block.
+    enter = __aenter__
+
+
+class AsyncRealtime(realtime.AsyncRealtime):
+    @override
+    def connect(
+        self,
+        *,
+        model: str,
+        extra_query: Query = {},
+        extra_headers: Headers = {},
+        websocket_connection_options: WebsocketConnectionOptions = {},
+    ) -> AsyncRealtimeConnectionManager:
+        """
+        The Realtime API enables you to build low-latency, multi-modal conversational experiences. It currently supports text and audio as both input and output, as well as function calling.
+
+        Some notable benefits of the API include:
+
+        - Native speech-to-speech: Skipping an intermediate text format means low latency and nuanced output.
+        - Natural, steerable voices: The models have natural inflection and can laugh, whisper, and adhere to tone direction.
+        - Simultaneous multimodal output: Text is useful for moderation; faster-than-realtime audio ensures stable playback.
+
+        The Realtime API is a stateful, event-based API that communicates over a WebSocket.
+
+        This override returns the Azure `AsyncRealtimeConnectionManager` defined in this module; the
+        WebSocket itself is opened by the manager's `__aenter__` / `enter()`.
+        """
+        # The `{}` defaults are shared objects; they are only passed through here, never mutated.
+        return AsyncRealtimeConnectionManager(
+            client=self._client,
+            extra_query=extra_query,
+            extra_headers=extra_headers,
+            websocket_connection_options=websocket_connection_options,
+            model=model,
+        )
+
+
+class AsyncBeta(beta.AsyncBeta):
+    @cached_property
+    def realtime(self) -> AsyncRealtime:  # type: ignore[reportImplicitOverride]
+        """Return the Azure-aware `AsyncRealtime` resource bound to this client."""
+        return AsyncRealtime(self._client)
diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py
index 13d9f31838..6320a84081 100644
--- a/src/openai/lib/azure.py
+++ b/src/openai/lib/azure.py
@@ -15,6 +15,7 @@
 from .._streaming import Stream, AsyncStream
 from .._exceptions import OpenAIError
 from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
+from ._azure_realtime import Beta, AsyncBeta
 
 _deployments_endpoints = set(
     [
@@ -65,6 +66,8 @@ def _build_request(
 
 
 class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
+    beta: Beta
+
     @overload
     def __init__(
         self,
@@ -224,6 +227,8 @@ def __init__(
         self._api_version = api_version
         self._azure_ad_token = azure_ad_token
         self._azure_ad_token_provider = azure_ad_token_provider
+        self._azure_deployment = azure_deployment
+        self.beta = Beta(self)  # type: ignore[reportIncompatibleVariableOverride]
 
     @override
     def copy(
@@ -309,6 +314,8 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
 
 
 class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
+    beta: AsyncBeta
+
     @overload
     def __init__(
         self,
@@ -471,6 +478,8 @@ def __init__(
         self._api_version = api_version
         self._azure_ad_token = azure_ad_token
         self._azure_ad_token_provider = azure_ad_token_provider
+        self._azure_deployment = azure_deployment
+        self.beta = AsyncBeta(self)  # type: ignore[reportIncompatibleVariableOverride]
 
     @override
     def copy(