Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@

from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES

api_key: str | None = None
api_key: str | _t.Callable[[], str] | None = None

organization: str | None = None

Expand Down Expand Up @@ -156,16 +156,27 @@ class _ModuleClient(OpenAI):

@property # type: ignore
@override
def api_key(self) -> str | None:
return api_key
def api_key(self) -> str | _t.Callable[[], str] | None:
return api_key() if callable(api_key) else api_key
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like it could be brittle and cause some weird behaviour if we access self.api_key in multiple places for the same request.

My gut is that we could just not support callable api keys for the module client?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @johanste offline and agreed we can remove callable api keys for the module level client.


@api_key.setter # type: ignore
def api_key(self, value: str | None) -> None: # type: ignore
def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore
global api_key
api_key = value

@property
def _api_key_provider(self) -> _t.Callable[[], str] | None: # type: ignore
return None

@_api_key_provider.setter
def _api_key_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore
global api_key
# Yes, setting the api_key is intentional. The module level client accepts callables
# for the module level api_key and will call it to retrieve the value
# if it is a callable.
api_key = value

@property # type: ignore
@property
@override
def organization(self) -> str | None:
return organization
Expand Down
48 changes: 39 additions & 9 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
from typing_extensions import Self, override

import httpx

from openai._models import FinalRequestOptions

from . import _exceptions
from ._qs import Querystring
from ._types import (
Expand Down Expand Up @@ -94,7 +96,7 @@ class OpenAI(SyncAPIClient):
def __init__(
self,
*,
api_key: str | None = None,
api_key: str | None | Callable[[], str] = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -132,7 +134,12 @@ def __init__(
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
if callable(api_key):
self.api_key = ""
self._api_key_provider: Callable[[], str] | None = api_key
else:
self.api_key = api_key or ""
self._api_key_provider = None

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -287,6 +294,15 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

def _refresh_api_key(self) -> None:
if self._api_key_provider:
self.api_key = self._api_key_provider()

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
self._refresh_api_key()
return super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -310,7 +326,7 @@ def default_headers(self) -> dict[str, str | Omit]:
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -348,7 +364,7 @@ def copy(

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
api_key=api_key or self._api_key_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
Expand Down Expand Up @@ -419,7 +435,7 @@ class AsyncOpenAI(AsyncAPIClient):
def __init__(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -457,7 +473,12 @@ def __init__(
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
if callable(api_key):
self.api_key = ""
self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
else:
self.api_key = api_key or ""
self._api_key_provider = None

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -612,6 +633,15 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

async def _refresh_api_key(self) -> None:
if self._api_key_provider:
self.api_key = await self._api_key_provider()

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
await self._refresh_api_key()
return await super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
Expand All @@ -635,7 +665,7 @@ def default_headers(self) -> dict[str, str | Omit]:
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -673,7 +703,7 @@ def copy(

http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
api_key=api_key or self._api_key_provider or self.api_key,
organization=organization or self.organization,
project=project or self.project,
webhook_secret=webhook_secret or self.webhook_secret,
Expand Down
8 changes: 4 additions & 4 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def __init__(
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -435,7 +435,7 @@ def __init__(
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
Expand Down Expand Up @@ -539,7 +539,7 @@ def __init__(
def copy(
self,
*,
api_key: str | None = None,
api_key: str | Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
webhook_secret: str | None = None,
Expand Down Expand Up @@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
Expand Down
2 changes: 2 additions & 0 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
await self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
Expand Down Expand Up @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
self.__client._refresh_api_key()
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
Expand Down
Loading
Loading