Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
coerce_integer as coerce_integer,
file_from_path as file_from_path,
parse_datetime as parse_datetime,
is_azure_client as is_azure_client,
strip_not_given as strip_not_given,
deepcopy_minimal as deepcopy_minimal,
get_async_library as get_async_library,
maybe_coerce_float as maybe_coerce_float,
get_required_header as get_required_header,
maybe_coerce_boolean as maybe_coerce_boolean,
maybe_coerce_integer as maybe_coerce_integer,
is_async_azure_client as is_async_azure_client,
configure_azure_realtime as configure_azure_realtime,
configure_azure_realtime_async as configure_azure_realtime_async,
)
from ._typing import (
is_list_type as is_list_type,
Expand Down
52 changes: 51 additions & 1 deletion src/openai/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import functools
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Mapping,
Expand All @@ -21,7 +22,7 @@

import sniffio

from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._types import Query, NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime

_T = TypeVar("_T")
Expand All @@ -30,6 +31,9 @@
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])

if TYPE_CHECKING:
from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]
Expand Down Expand Up @@ -412,3 +416,49 @@ def json_safe(data: object) -> object:
return data.isoformat()

return data


def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
from ..lib.azure import AzureOpenAI

return isinstance(client, AzureOpenAI)


def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
from ..lib.azure import AsyncAzureOpenAI

return isinstance(client, AsyncAzureOpenAI)


def configure_azure_realtime(client: AzureOpenAI, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is the reason the configure_azure_realtime and configure_azure_realtime_async functions are in _utils.py because of circular imports?

I think it'd be cleaner to define these as private methods on the client classes themselves, then the __aenter__ method changes are just

if is_async_azure_client(self.__client):
  extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, that's a lot cleaner. Updated!

auth_headers = {}
query = {
**extra_query,
"api-version": client._api_version,
"deployment": model,
}
if client.api_key != "<missing API key>":
auth_headers = {"api-key": client.api_key}
else:
token = client._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers


async def configure_azure_realtime_async(
client: AsyncAzureOpenAI, model: str, extra_query: Query
) -> tuple[Query, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": client._api_version,
"deployment": model,
}
if client.api_key != "<missing API key>":
auth_headers = {"api-key": client.api_key}
else:
token = await client._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers
22 changes: 18 additions & 4 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
)
from ...._types import NOT_GIVEN, Query, Headers, NotGiven
from ...._utils import (
is_azure_client,
maybe_transform,
strip_not_given,
async_maybe_transform,
is_async_azure_client,
configure_azure_realtime,
configure_azure_realtime_async,
)
from ...._compat import cached_property
from ...._models import construct_type_unchecked
Expand Down Expand Up @@ -319,11 +323,16 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
except ImportError as exc:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
extra_query, auth_headers = await configure_azure_realtime_async(self.__client, self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**self.__extra_query,
**extra_query,
},
)
log.debug("Connecting to %s", url)
Expand All @@ -336,7 +345,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
user_agent_header=self.__client.user_agent,
additional_headers=_merge_mappings(
{
**self.__client.auth_headers,
**auth_headers,
"OpenAI-Beta": "realtime=v1",
},
self.__extra_headers,
Expand Down Expand Up @@ -496,11 +505,16 @@ def __enter__(self) -> RealtimeConnection:
except ImportError as exc:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
extra_query, auth_headers = configure_azure_realtime(self.__client, self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**self.__extra_query,
**extra_query,
},
)
log.debug("Connecting to %s", url)
Expand All @@ -513,7 +527,7 @@ def __enter__(self) -> RealtimeConnection:
user_agent_header=self.__client.user_agent,
additional_headers=_merge_mappings(
{
**self.__client.auth_headers,
**auth_headers,
"OpenAI-Beta": "realtime=v1",
},
self.__extra_headers,
Expand Down
Loading