Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@

from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
from .version import VERSION as VERSION
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
from .lib.azure import (
AzureAuth as AzureAuth,
AzureOpenAI as AzureOpenAI,
AsyncAzureAuth as AsyncAzureAuth,
AsyncAzureOpenAI as AsyncAzureOpenAI,
)
from .lib._old_api import *
from .lib.streaming import (
AssistantEventHandler as AssistantEventHandler,
Expand Down Expand Up @@ -119,6 +124,8 @@

api_key: str | None = None

auth: AzureAuth | None = None

organization: str | None = None

project: str | None = None
Expand Down Expand Up @@ -165,6 +172,17 @@ def api_key(self, value: str | None) -> None: # type: ignore

api_key = value

@property # type: ignore
@override
def auth(self) -> AzureAuth | None:
return auth

@auth.setter # type: ignore
def auth(self, value: AzureAuth | None) -> None: # type: ignore
global auth

auth = value

@property # type: ignore
@override
def organization(self) -> str | None:
Expand Down Expand Up @@ -348,6 +366,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]

_client = _ModuleClient(
api_key=api_key,
auth=auth,
organization=organization,
project=project,
webhook_secret=webhook_secret,
Expand Down
79 changes: 63 additions & 16 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_async_library,
)
from ._compat import cached_property
from ._models import FinalRequestOptions
from ._version import __version__
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
from ._exceptions import OpenAIError, APIStatusError
Expand All @@ -35,6 +36,7 @@
)

if TYPE_CHECKING:
from .lib.azure import AzureAuth, AsyncAzureAuth
from .resources import (
beta,
chat,
Expand Down Expand Up @@ -93,6 +95,7 @@ def __init__(
self,
*,
api_key: str | None = None,
auth: AzureAuth | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -124,13 +127,16 @@ def __init__(
- `project` from `OPENAI_PROJECT_ID`
- `webhook_secret` from `OPENAI_WEBHOOK_SECRET`
"""
if api_key and auth:
raise ValueError("The `api_key` and `auth` arguments are mutually exclusive")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
if api_key is None and auth is None:
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
"The api_key or auth client option must be set either by passing api_key or auth to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
self.auth = auth
self.api_key = api_key or ""

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -163,6 +169,7 @@ def __init__(
)

self._default_stream_cls = Stream
self._auth_headers: dict[str, str] = {}

@cached_property
def completions(self) -> Completions:
Expand Down Expand Up @@ -279,14 +286,25 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

def refresh_auth_headers(self) -> None:
secret = self.auth.get_token() if self.auth else self.api_key
if not secret:
# if secret is an empty string, encoding the header will fail
# so we set it to an empty dict
# this is to avoid sending an invalid Authorization header
self._auth_headers = {}
else:
self._auth_headers = {"Authorization": f"Bearer {secret}"}

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
self.refresh_auth_headers()
return super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if not api_key:
# if the api key is an empty string, encoding the header will fail
return {}
return {"Authorization": f"Bearer {api_key}"}
return self._auth_headers

@property
@override
Expand All @@ -303,6 +321,7 @@ def copy(
self,
*,
api_key: str | None = None,
auth: AzureAuth | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -338,6 +357,10 @@ def copy(
elif set_default_query is not None:
params = set_default_query

auth = auth or self.auth
if auth is not None:
_extra_kwargs = {**_extra_kwargs, "auth": auth}

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
Expand Down Expand Up @@ -412,6 +435,7 @@ def __init__(
self,
*,
api_key: str | None = None,
auth: AsyncAzureAuth | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -443,13 +467,16 @@ def __init__(
- `project` from `OPENAI_PROJECT_ID`
- `webhook_secret` from `OPENAI_WEBHOOK_SECRET`
"""
if api_key and auth:
raise ValueError("The `api_key` and `auth` arguments are mutually exclusive")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
if api_key is None and auth is None:
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
"The api_key or auth client option must be set either by passing api_key or auth to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
self.auth = auth
self.api_key = api_key or ""

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -482,6 +509,7 @@ def __init__(
)

self._default_stream_cls = AsyncStream
self._auth_headers: dict[str, str] = {}

@cached_property
def completions(self) -> AsyncCompletions:
Expand Down Expand Up @@ -598,14 +626,28 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

async def refresh_auth_headers(self) -> None:
if self.auth:
secret = await self.auth.get_token()
else:
secret = self.api_key
if not secret:
# if the secret is an empty string, encoding the header will fail
# so we set it to an empty dict
# this is to avoid sending an invalid Authorization header
self._auth_headers = {}
else:
self._auth_headers = {"Authorization": f"Bearer {secret}"}

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
await self.refresh_auth_headers()
return await super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if not api_key:
# if the api key is an empty string, encoding the header will fail
return {}
return {"Authorization": f"Bearer {api_key}"}
return self._auth_headers

@property
@override
Expand All @@ -622,6 +664,7 @@ def copy(
self,
*,
api_key: str | None = None,
auth: AsyncAzureAuth | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -657,6 +700,10 @@ def copy(
elif set_default_query is not None:
params = set_default_query

auth = auth or self.auth
if auth is not None:
_extra_kwargs = {**_extra_kwargs, "auth": auth}

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
Expand Down
111 changes: 106 additions & 5 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing_extensions import Self, override

import httpx
Expand Down Expand Up @@ -42,6 +42,107 @@
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])


if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential


AzureTokenProvider = Callable[[], str]
AsyncAzureTokenProvider = Callable[[], Awaitable[str]]


class AzureAuth:
@overload
def __init__(self, *, token_provider: AzureTokenProvider) -> None: ...

@overload
def __init__(
self,
*,
credential: TokenCredential,
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"],
) -> None: ...

def __init__(
self,
*,
token_provider: AzureTokenProvider | None = None,
credential: TokenCredential | None = None,
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, parameters of type list[str] need to be checked correct type passed in. You will get very unexpected results if you accidentally pass in a str.

) -> None:
if token_provider is not None and credential is not None:
raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.")
if token_provider is None and credential is None:
raise ValueError("One of `token_provider` or `credential` must be provided to AzureAuth.")
self.token_provider = token_provider
self.credential = credential
self.scopes = scopes

def get_token(self) -> str:
if self.token_provider is not None:
token = self.token_provider()
return token

if self.credential is not None:
try:
from azure.identity import get_bearer_token_provider
except ImportError as err:
raise ImportError(
"azure-identity library is not installed. Please install it to use AzureAuth."
) from err
token_provider = get_bearer_token_provider(self.credential, *self.scopes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to cache this on self.

token = token_provider()
return token

raise ValueError("Unexpected values provided to AzureAuth. Unable to get token.")


class AsyncAzureAuth:
@overload
def __init__(self, *, token_provider: AsyncAzureTokenProvider) -> None: ...

@overload
def __init__(
self,
*,
credential: AsyncTokenCredential,
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"],
) -> None: ...

def __init__(
self,
*,
token_provider: AsyncAzureTokenProvider | None = None,
credential: AsyncTokenCredential | None = None,
scopes: list[str] = ["https://cognitiveservices.azure.com/.default"],
) -> None:
if token_provider is not None and credential is not None:
raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.")
if token_provider is None and credential is None:
raise ValueError("One of `token_provider` or `credential` must be provided to AsyncAzureAuth.")
self.token_provider = token_provider
self.credential = credential
self.scopes = scopes

async def get_token(self) -> str:
if self.token_provider is not None:
token = await self.token_provider()
return token

if self.credential is not None:
Copy link
Collaborator

@johanste johanste Aug 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AzureAsyncAuth instance needs to cache the bearer token provider.

try:
from azure.identity.aio import get_bearer_token_provider
except ImportError as err:
raise ImportError(
"azure-identity library is not installed. Please install it to use AsyncAzureAuth."
) from err
token_provider = get_bearer_token_provider(self.credential, *self.scopes)
token = await token_provider()
return token

raise ValueError("Unexpected values provided to AsyncAzureAuth. Unable to get token.")


class MutuallyExclusiveAuthError(OpenAIError):
def __init__(self) -> None:
super().__init__(
Expand Down Expand Up @@ -255,7 +356,7 @@ def __init__(
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
def copy( # type: ignore
self,
*,
api_key: str | None = None,
Expand Down Expand Up @@ -301,7 +402,7 @@ def copy(
},
)

with_options = copy
with_options = copy # type: ignore

def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
Expand Down Expand Up @@ -536,7 +637,7 @@ def __init__(
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
def copy( # type: ignore
self,
*,
api_key: str | None = None,
Expand Down Expand Up @@ -582,7 +683,7 @@ def copy(
},
)

with_options = copy
with_options = copy # type: ignore

async def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
Expand Down
Loading
Loading