From 9445e3caf2ab2af3d9b830da2bcc1f5ad0d25c25 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Tue, 17 Dec 2024 18:10:13 -0800 Subject: [PATCH 1/2] monkeypatch connect --- src/openai/lib/_azure_realtime.py | 215 ++++++++++++++++++++++++++++++ src/openai/lib/azure.py | 10 ++ 2 files changed, 225 insertions(+) create mode 100644 src/openai/lib/_azure_realtime.py diff --git a/src/openai/lib/_azure_realtime.py b/src/openai/lib/_azure_realtime.py new file mode 100644 index 0000000000..0927e6a1b1 --- /dev/null +++ b/src/openai/lib/_azure_realtime.py @@ -0,0 +1,215 @@ +import logging +from typing import TYPE_CHECKING + +from .._exceptions import OpenAIError +from ..resources import beta +from ..resources.beta import realtime +from .._compat import cached_property +from .._types import Query, Headers +from ..types.websocket_connection_options import WebsocketConnectionOptions +from .._base_client import _merge_mappings + +if TYPE_CHECKING: + from .azure import AzureOpenAI, AsyncAzureOpenAI + +log: logging.Logger = logging.getLogger(__name__) + + +class RealtimeConnectionManager(realtime.realtime.RealtimeConnectionManager): + def __enter__(self) -> realtime.realtime.RealtimeConnection: + """ + 👋 If your application doesn't work well with the context manager approach then you + can call this method directly to initiate a connection. + + **Warning**: You must remember to close the connection with `.close()`. + + ```py + connection = client.beta.realtime.connect(...).enter() + # ... + connection.close() + ``` + """ + try: + from websockets.sync.client import connect + except ImportError as exc: + raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc + + auth_headers = self.__client.auth_headers + extra_query = self.__extra_query + self.__client: AzureOpenAI + extra_query = { + **self.__extra_query, + "api-version": self.__client._api_version, + "deployment": self.__client._azure_deployment or self.__model + } + if self.__client.api_key != "": + auth_headers = {"api-key": self.__client.api_key} + else: + token = self.__client._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + + url = self._prepare_url().copy_with( + params={ + **self.__client.base_url.params, + "model": self.__model, + **extra_query, + }, + ) + log.debug("Connecting to %s", url) + if self.__websocket_connection_options: + log.debug("Connection options: %s", self.__websocket_connection_options) + + self.__connection = realtime.realtime.RealtimeConnection( + connect( + str(url), + user_agent_header=self.__client.user_agent, + additional_headers=_merge_mappings( + { + **auth_headers, + "OpenAI-Beta": "realtime=v1", + }, + self.__extra_headers, + ), + **self.__websocket_connection_options, + ) + ) + + return self.__connection + + enter = __enter__ + + +class Realtime(realtime.Realtime): + + def connect( + self, + *, + model: str, + extra_query: Query = {}, + extra_headers: Headers = {}, + websocket_connection_options: WebsocketConnectionOptions = {}, + ) -> RealtimeConnectionManager: + """ + The Realtime API enables you to build low-latency, multi-modal conversational experiences. It currently supports text and audio as both input and output, as well as function calling. + + Some notable benefits of the API include: + + - Native speech-to-speech: Skipping an intermediate text format means low latency and nuanced output. + - Natural, steerable voices: The models have natural inflection and can laugh, whisper, and adhere to tone direction. + - Simultaneous multimodal output: Text is useful for moderation; faster-than-realtime audio ensures stable playback. + + The Realtime API is a stateful, event-based API that communicates over a WebSocket. + """ + return RealtimeConnectionManager( + client=self._client, + extra_query=extra_query, + extra_headers=extra_headers, + websocket_connection_options=websocket_connection_options, + model=model, + ) + +class Beta(beta.Beta): + @cached_property + def realtime(self) -> Realtime: + return Realtime(self._client) + + +class AsyncRealtimeConnectionManager(realtime.realtime.AsyncRealtimeConnectionManager): + async def __aenter__(self) -> realtime.realtime.AsyncRealtimeConnection: + """ + 👋 If your application doesn't work well with the context manager approach then you + can call this method directly to initiate a connection. + + **Warning**: You must remember to close the connection with `.close()`. + + ```py + connection = client.beta.realtime.connect(...).enter() + # ... + connection.close() + ``` + """ + try: + from websockets.asyncio.client import connect + except ImportError as exc: + raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc + + auth_headers = self.__client.auth_headers + extra_query = self.__extra_query + self.__client: AsyncAzureOpenAI + extra_query = { + **self.__extra_query, + "api-version": self.__client._api_version, + "deployment": self.__client._azure_deployment or self.__model + } + if self.__client.api_key != "": + auth_headers = {"api-key": self.__client.api_key} + else: + token = await self.__client._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + + url = self._prepare_url().copy_with( + params={ + **self.__client.base_url.params, + "model": self.__model, + **extra_query, + }, + ) + log.debug("Connecting to %s", url) + if self.__websocket_connection_options: + log.debug("Connection options: %s", self.__websocket_connection_options) + + self.__connection = realtime.realtime.AsyncRealtimeConnection( + await connect( + str(url), + user_agent_header=self.__client.user_agent, + additional_headers=_merge_mappings( + { + **auth_headers, + "OpenAI-Beta": "realtime=v1", + }, + self.__extra_headers, + ), + **self.__websocket_connection_options, + ) + ) + + return self.__connection + + enter = __aenter__ + + +class AsyncRealtime(realtime.AsyncRealtime): + + def connect( + self, + *, + model: str, + extra_query: Query = {}, + extra_headers: Headers = {}, + websocket_connection_options: WebsocketConnectionOptions = {}, + ) -> AsyncRealtimeConnectionManager: + """ + The Realtime API enables you to build low-latency, multi-modal conversational experiences. It currently supports text and audio as both input and output, as well as function calling. + + Some notable benefits of the API include: + + - Native speech-to-speech: Skipping an intermediate text format means low latency and nuanced output. + - Natural, steerable voices: The models have natural inflection and can laugh, whisper, and adhere to tone direction. + - Simultaneous multimodal output: Text is useful for moderation; faster-than-realtime audio ensures stable playback. + + The Realtime API is a stateful, event-based API that communicates over a WebSocket. + """ + return AsyncRealtimeConnectionManager( + client=self._client, + extra_query=extra_query, + extra_headers=extra_headers, + websocket_connection_options=websocket_connection_options, + model=model, + ) + +class AsyncBeta(beta.AsyncBeta): + @cached_property + def realtime(self) -> AsyncRealtime: + return AsyncRealtime(self._client) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 13d9f31838..609a4d0ec1 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -15,6 +15,8 @@ from .._streaming import Stream, AsyncStream from .._exceptions import OpenAIError from .._base_client import DEFAULT_MAX_RETRIES, BaseClient +from ._azure_realtime import Beta, AsyncBeta + _deployments_endpoints = set( [ @@ -65,6 +67,8 @@ def _build_request( class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI): + beta: Beta + @overload def __init__( self, @@ -224,6 +228,8 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + self._azure_deployment = azure_deployment + self.beta = Beta(self) @override def copy( @@ -309,6 +315,8 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI): + beta: AsyncBeta + @overload def __init__( self, @@ -471,6 +479,8 @@ def __init__( self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + self._azure_deployment = azure_deployment + self.beta = AsyncBeta(self) @override def copy( From 82b5830c92d8cca4644e62d87c7762b837574774 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 18 Dec 2024 10:26:49 -0800 Subject: [PATCH 2/2] lint --- src/openai/lib/_azure_realtime.py | 25 +++++++++++++++---------- src/openai/lib/azure.py | 5 ++--- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/openai/lib/_azure_realtime.py b/src/openai/lib/_azure_realtime.py index 0927e6a1b1..0954664134 100644 --- a/src/openai/lib/_azure_realtime.py +++ b/src/openai/lib/_azure_realtime.py @@ -1,13 +1,14 @@ import logging from typing import TYPE_CHECKING +from typing_extensions import override -from .._exceptions import OpenAIError +from .._types import Query, Headers +from .._compat import cached_property from ..resources import beta +from .._exceptions import OpenAIError +from .._base_client import _merge_mappings from ..resources.beta import realtime -from .._compat import cached_property -from .._types import Query, Headers from ..types.websocket_connection_options import WebsocketConnectionOptions -from .._base_client import _merge_mappings if TYPE_CHECKING: from .azure import AzureOpenAI, AsyncAzureOpenAI @@ -16,6 +17,7 @@ class RealtimeConnectionManager(realtime.realtime.RealtimeConnectionManager): + @override def __enter__(self) -> realtime.realtime.RealtimeConnection: """ 👋 If your application doesn't work well with the context manager approach then you @@ -40,7 +42,7 @@ def __enter__(self) -> realtime.realtime.RealtimeConnection: extra_query = { **self.__extra_query, "api-version": self.__client._api_version, - "deployment": self.__client._azure_deployment or self.__model + "deployment": self.__client._azure_deployment or self.__model, } if self.__client.api_key != "": auth_headers = {"api-key": self.__client.api_key} @@ -81,7 +83,7 @@ def __enter__(self) -> realtime.realtime.RealtimeConnection: class Realtime(realtime.Realtime): - + @override def connect( self, *, @@ -109,13 +111,15 @@ def connect( model=model, ) + class Beta(beta.Beta): @cached_property - def realtime(self) -> Realtime: + def realtime(self) -> Realtime: # type: ignore[reportImplicitOverride] return Realtime(self._client) class AsyncRealtimeConnectionManager(realtime.realtime.AsyncRealtimeConnectionManager): + @override async def __aenter__(self) -> realtime.realtime.AsyncRealtimeConnection: """ 👋 If your application doesn't work well with the context manager approach then you @@ -140,7 +144,7 @@ async def __aenter__(self) -> realtime.realtime.AsyncRealtimeConnection: extra_query = { **self.__extra_query, "api-version": self.__client._api_version, - "deployment": self.__client._azure_deployment or self.__model + "deployment": self.__client._azure_deployment or self.__model, } if self.__client.api_key != "": auth_headers = {"api-key": self.__client.api_key} @@ -181,7 +185,7 @@ async def __aenter__(self) -> realtime.realtime.AsyncRealtimeConnection: class AsyncRealtime(realtime.AsyncRealtime): - + @override def connect( self, *, @@ -209,7 +213,8 @@ def connect( model=model, ) + class AsyncBeta(beta.AsyncBeta): @cached_property - def realtime(self) -> AsyncRealtime: + def realtime(self) -> AsyncRealtime: # type: ignore[reportImplicitOverride] return AsyncRealtime(self._client) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 609a4d0ec1..6320a84081 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -17,7 +17,6 @@ from .._base_client import DEFAULT_MAX_RETRIES, BaseClient from ._azure_realtime import Beta, AsyncBeta - _deployments_endpoints = set( [ "/completions", @@ -229,7 +228,7 @@ def __init__( self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider self._azure_deployment = azure_deployment - self.beta = Beta(self) + self.beta = Beta(self) # type: ignore[reportIncompatibleVariableOverride] @override def copy( @@ -480,7 +479,7 @@ def __init__( self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider self._azure_deployment = azure_deployment - self.beta = AsyncBeta(self) + self.beta = AsyncBeta(self) # type: ignore[reportIncompatibleVariableOverride] @override def copy(