Skip to content

Commit 80672db

Browse files
committed
extract azure logic out of enter
1 parent 69d816d commit 80672db

File tree

3 files changed

+49
-27
lines changed

3 files changed

+49
-27
lines changed

src/openai/_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
maybe_coerce_boolean as maybe_coerce_boolean,
3535
maybe_coerce_integer as maybe_coerce_integer,
3636
is_async_azure_client as is_async_azure_client,
37+
configure_azure_realtime as configure_azure_realtime,
38+
configure_azure_realtime_async as configure_azure_realtime_async,
3739
)
3840
from ._typing import (
3941
is_list_type as is_list_type,

src/openai/_utils/_utils.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import sniffio
2424

25-
from .._types import NotGiven, FileTypes, NotGivenOr, HeadersLike
25+
from .._types import Query, NotGiven, FileTypes, NotGivenOr, HeadersLike
2626
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
2727

2828
_T = TypeVar("_T")
@@ -419,8 +419,46 @@ def json_safe(data: object) -> object:
419419

420420

421421
def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
422-
return hasattr(client, "_azure_ad_token_provider")
422+
from ..lib.azure import AzureOpenAI
423+
424+
return isinstance(client, AzureOpenAI)
423425

424426

425427
def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
426-
return hasattr(client, "_azure_ad_token_provider")
428+
from ..lib.azure import AsyncAzureOpenAI
429+
430+
return isinstance(client, AsyncAzureOpenAI)
431+
432+
433+
def configure_azure_realtime(client: AzureOpenAI, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
434+
auth_headers = {}
435+
query = {
436+
**extra_query,
437+
"api-version": client._api_version,
438+
"deployment": client._azure_deployment or model,
439+
}
440+
if client.api_key != "<missing API key>":
441+
auth_headers = {"api-key": client.api_key}
442+
else:
443+
token = client._get_azure_ad_token()
444+
if token:
445+
auth_headers = {"Authorization": f"Bearer {token}"}
446+
return query, auth_headers
447+
448+
449+
async def configure_azure_realtime_async(
450+
client: AsyncAzureOpenAI, model: str, extra_query: Query
451+
) -> tuple[Query, dict[str, str]]:
452+
auth_headers = {}
453+
query = {
454+
**extra_query,
455+
"api-version": client._api_version,
456+
"deployment": client._azure_deployment or model,
457+
}
458+
if client.api_key != "<missing API key>":
459+
auth_headers = {"api-key": client.api_key}
460+
else:
461+
token = await client._get_azure_ad_token()
462+
if token:
463+
auth_headers = {"Authorization": f"Bearer {token}"}
464+
return query, auth_headers

src/openai/resources/beta/realtime/realtime.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
strip_not_given,
2727
async_maybe_transform,
2828
is_async_azure_client,
29+
configure_azure_realtime,
30+
configure_azure_realtime_async,
2931
)
3032
from ...._compat import cached_property
3133
from ...._models import construct_type_unchecked
@@ -321,20 +323,10 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
321323
except ImportError as exc:
322324
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
323325

324-
auth_headers = self.__client.auth_headers
325326
extra_query = self.__extra_query
327+
auth_headers = self.__client.auth_headers
326328
if is_async_azure_client(self.__client):
327-
extra_query = {
328-
**self.__extra_query,
329-
"api-version": self.__client._api_version,
330-
"deployment": self.__client._azure_deployment or self.__model,
331-
}
332-
if self.__client.api_key != "<missing API key>":
333-
auth_headers = {"api-key": self.__client.api_key}
334-
else:
335-
token = await self.__client._get_azure_ad_token()
336-
if token:
337-
auth_headers = {"Authorization": f"Bearer {token}"}
329+
extra_query, auth_headers = await configure_azure_realtime_async(self.__client, self.__model, extra_query)
338330

339331
url = self._prepare_url().copy_with(
340332
params={
@@ -513,20 +505,10 @@ def __enter__(self) -> RealtimeConnection:
513505
except ImportError as exc:
514506
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
515507

516-
auth_headers = self.__client.auth_headers
517508
extra_query = self.__extra_query
509+
auth_headers = self.__client.auth_headers
518510
if is_azure_client(self.__client):
519-
extra_query = {
520-
**self.__extra_query,
521-
"api-version": self.__client._api_version,
522-
"deployment": self.__client._azure_deployment or self.__model,
523-
}
524-
if self.__client.api_key != "<missing API key>":
525-
auth_headers = {"api-key": self.__client.api_key}
526-
else:
527-
token = self.__client._get_azure_ad_token()
528-
if token:
529-
auth_headers = {"Authorization": f"Bearer {token}"}
511+
extra_query, auth_headers = configure_azure_realtime(self.__client, self.__model, extra_query)
530512

531513
url = self._prepare_url().copy_with(
532514
params={

0 commit comments

Comments
 (0)