Skip to content

Commit 69d816d

Browse files
committed
lint
1 parent d7d991d commit 69d816d

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

src/openai/_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
coerce_integer as coerce_integer,
2626
file_from_path as file_from_path,
2727
parse_datetime as parse_datetime,
28+
is_azure_client as is_azure_client,
2829
strip_not_given as strip_not_given,
2930
deepcopy_minimal as deepcopy_minimal,
3031
get_async_library as get_async_library,
3132
maybe_coerce_float as maybe_coerce_float,
3233
get_required_header as get_required_header,
3334
maybe_coerce_boolean as maybe_coerce_boolean,
3435
maybe_coerce_integer as maybe_coerce_integer,
36+
is_async_azure_client as is_async_azure_client,
3537
)
3638
from ._typing import (
3739
is_list_type as is_list_type,

src/openai/_utils/_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import inspect
66
import functools
77
from typing import (
8+
TYPE_CHECKING,
89
Any,
910
Tuple,
1011
Mapping,
@@ -30,6 +31,9 @@
3031
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
3132
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
3233

34+
if TYPE_CHECKING:
35+
from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
36+
3337

3438
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
3539
return [item for sublist in t for item in sublist]
@@ -412,3 +416,11 @@ def json_safe(data: object) -> object:
412416
return data.isoformat()
413417

414418
return data
419+
420+
421+
def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
422+
return hasattr(client, "_azure_ad_token_provider")
423+
424+
425+
def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
426+
return hasattr(client, "_azure_ad_token_provider")

src/openai/resources/beta/realtime/realtime.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
)
2222
from ...._types import NOT_GIVEN, Query, Headers, NotGiven
2323
from ...._utils import (
24+
is_azure_client,
2425
maybe_transform,
2526
strip_not_given,
2627
async_maybe_transform,
28+
is_async_azure_client,
2729
)
2830
from ...._compat import cached_property
2931
from ...._models import construct_type_unchecked
@@ -321,11 +323,11 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
321323

322324
auth_headers = self.__client.auth_headers
323325
extra_query = self.__extra_query
324-
if self.__client.__class__.__name__ == "AsyncAzureOpenAI":
326+
if is_async_azure_client(self.__client):
325327
extra_query = {
326328
**self.__extra_query,
327329
"api-version": self.__client._api_version,
328-
"deployment": self.__client._azure_deployment or self.__model
330+
"deployment": self.__client._azure_deployment or self.__model,
329331
}
330332
if self.__client.api_key != "<missing API key>":
331333
auth_headers = {"api-key": self.__client.api_key}
@@ -513,11 +515,11 @@ def __enter__(self) -> RealtimeConnection:
513515

514516
auth_headers = self.__client.auth_headers
515517
extra_query = self.__extra_query
516-
if self.__client.__class__.__name__ == "AzureOpenAI":
518+
if is_azure_client(self.__client):
517519
extra_query = {
518520
**self.__extra_query,
519521
"api-version": self.__client._api_version,
520-
"deployment": self.__client._azure_deployment or self.__model
522+
"deployment": self.__client._azure_deployment or self.__model,
521523
}
522524
if self.__client.api_key != "<missing API key>":
523525
auth_headers = {"api-key": self.__client.api_key}

0 commit comments

Comments
 (0)