Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@

from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES

api_key: str | None = None
api_key: str | _t.Callable[[], str] | None = None

organization: str | None = None

Expand Down Expand Up @@ -156,15 +156,15 @@ class _ModuleClient(OpenAI):

@property # type: ignore
@override
def api_key(self) -> str | None:
def api_key(self) -> str | _t.Callable[[], str] | None:
return api_key

@api_key.setter # type: ignore
def api_key(self, value: str | None) -> None: # type: ignore
def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore
global api_key

api_key = value


@property # type: ignore
@override
def organization(self) -> str | None:
Expand Down
82 changes: 63 additions & 19 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
from typing_extensions import Self, override

import httpx

from openai._models import FinalRequestOptions

from . import _exceptions
from ._qs import Querystring
from ._types import (
Expand Down Expand Up @@ -79,6 +81,7 @@
class OpenAI(SyncAPIClient):
# client options
api_key: str
bearer_token_provider: Callable[[], str] | None = None
organization: str | None
project: str | None
webhook_secret: str | None
Expand All @@ -94,7 +97,7 @@ class OpenAI(SyncAPIClient):
def __init__(
self,
*,
api_key: str | None = None,
api_key: str | None | Callable[[], str] = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -132,7 +135,12 @@ def __init__(
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
if callable(api_key):
self.api_key = ""
self.bearer_token_provider = api_key
else:
self.api_key = api_key or ""
self.bearer_token_provider = None

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -165,6 +173,7 @@ def __init__(
)

self._default_stream_cls = Stream
self._auth_headers: dict[str, str] = {}

@cached_property
def completions(self) -> Completions:
Expand Down Expand Up @@ -287,14 +296,28 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

def refresh_auth_headers(self) -> None:
if self.bearer_token_provider:
secret = self.bearer_token_provider()
else:
secret = self.api_key
if not secret:
# if the api key is an empty string, encoding the header will fail
# so we set it to an empty dict
# this is to avoid sending an invalid Authorization header
self._auth_headers = {}
else:
self._auth_headers = {"Authorization": f"Bearer {secret}"}

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
self.refresh_auth_headers()
return super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if not api_key:
# if the api key is an empty string, encoding the header will fail
return {}
return {"Authorization": f"Bearer {api_key}"}
return self._auth_headers

@property
@override
Expand All @@ -310,7 +333,7 @@ def default_headers(self) -> dict[str, str | Omit]:
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -348,7 +371,7 @@ def copy(

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
api_key=api_key or self.bearer_token_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
Expand Down Expand Up @@ -404,6 +427,7 @@ def _make_status_error(
class AsyncOpenAI(AsyncAPIClient):
# client options
api_key: str
bearer_token_provider: Callable[[], Awaitable[str]] | None = None
organization: str | None
project: str | None
webhook_secret: str | None
Expand All @@ -419,7 +443,7 @@ class AsyncOpenAI(AsyncAPIClient):
def __init__(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -457,7 +481,12 @@ def __init__(
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
if callable(api_key):
self.api_key = ""
self.bearer_token_provider = api_key
else:
self.api_key = api_key or ""
self.bearer_token_provider = None

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -490,6 +519,7 @@ def __init__(
)

self._default_stream_cls = AsyncStream
self._auth_headers: dict[str, str] = {}

@cached_property
def completions(self) -> AsyncCompletions:
Expand Down Expand Up @@ -612,14 +642,28 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

async def refresh_auth_headers(self) -> None:
if self.bearer_token_provider:
secret = await self.bearer_token_provider()
else:
secret = self.api_key
if not secret:
# if the api key is an empty string, encoding the header will fail
# so we set it to an empty dict
# this is to avoid sending an invalid Authorization header
self._auth_headers = {}
else:
self._auth_headers = {"Authorization": f"Bearer {secret}"}

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
await self.refresh_auth_headers()
return await super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if not api_key:
# if the api key is an empty string, encoding the header will fail
return {}
return {"Authorization": f"Bearer {api_key}"}
return self._auth_headers

@property
@override
Expand All @@ -635,7 +679,7 @@ def default_headers(self) -> dict[str, str | Omit]:
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -673,7 +717,7 @@ def copy(

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
api_key=api_key or self.bearer_token_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
Expand Down
10 changes: 5 additions & 5 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def __init__(
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
def copy( # type: ignore
self,
*,
api_key: str | None = None,
Expand Down Expand Up @@ -301,7 +301,7 @@ def copy(
},
)

with_options = copy
with_options = copy # type: ignore

def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
Expand Down Expand Up @@ -536,7 +536,7 @@ def __init__(
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
def copy( # type: ignore
self,
*,
api_key: str | None = None,
Expand Down Expand Up @@ -582,7 +582,7 @@ def copy(
},
)

with_options = copy
with_options = copy # type: ignore

async def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
Expand Down Expand Up @@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
Expand Down
2 changes: 2 additions & 0 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
await self.__client.refresh_auth_headers()
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
Expand Down Expand Up @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
self.__client.refresh_auth_headers()
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
Expand Down
Loading
Loading