diff --git a/src/openai/_client.py b/src/openai/_client.py index fe5ebac42a..2be32fe13f 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -3,7 +3,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, Union, Mapping +from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable from typing_extensions import Self, override import httpx @@ -25,6 +25,7 @@ get_async_library, ) from ._compat import cached_property +from ._models import FinalRequestOptions from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import OpenAIError, APIStatusError @@ -96,7 +97,7 @@ class OpenAI(SyncAPIClient): def __init__( self, *, - api_key: str | None = None, + api_key: str | None | Callable[[], str] = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -134,7 +135,12 @@ def __init__( raise OpenAIError( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + if callable(api_key): + self.api_key = "" + self._api_key_provider: Callable[[], str] | None = api_key + else: + self.api_key = api_key + self._api_key_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -295,6 +301,15 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + def _refresh_api_key(self) -> None: + if self._api_key_provider: + self.api_key = self._api_key_provider() + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + self._refresh_api_key() + return super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: @@ -318,7 +333,7 @@ def default_headers(self) -> dict[str, str | Omit]: def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -356,7 +371,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + api_key=api_key or self._api_key_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, @@ -427,7 +442,7 @@ class AsyncOpenAI(AsyncAPIClient): def __init__( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -465,7 +480,12 @@ def __init__( raise OpenAIError( "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + if callable(api_key): + self.api_key = "" + self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key + else: + self.api_key = api_key + self._api_key_provider = None if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -626,6 +646,15 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + async def _refresh_api_key(self) -> None: + if self._api_key_provider: + self.api_key = await self._api_key_provider() + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + await self._refresh_api_key() + return await super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: @@ -649,7 +678,7 @@ def default_headers(self) -> dict[str, str | Omit]: def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -687,7 +716,7 @@ def copy( http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + api_key=api_key or self._api_key_provider or self.api_key, organization=organization or self.organization, project=project or self.project, webhook_secret=webhook_secret or self.webhook_secret, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index a994e4256c..ad64707261 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -94,7 +94,7 @@ def __init__( azure_endpoint: str, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, organization: str | None = None, @@ -114,7 +114,7 @@ def __init__( *, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, organization: str | None = None, @@ -134,7 +134,7 @@ def __init__( *, base_url: str, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, organization: str | None = None, @@ -154,7 +154,7 @@ def __init__( api_version: str | None = None, azure_endpoint: str | None = None, azure_deployment: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, organization: str | None = None, @@ -258,7 +258,7 @@ def __init__( def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], str] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -345,7 +345,7 @@ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL "api-version": self._api_version, "deployment": self._azure_deployment or model, } - if self.api_key != "": + if self.api_key and self.api_key != "": auth_headers = {"api-key": self.api_key} else: token = self._get_azure_ad_token() @@ -372,7 +372,7 @@ def __init__( azure_endpoint: str, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -393,7 +393,7 @@ def __init__( *, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -414,7 +414,7 @@ def __init__( *, base_url: str, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -435,7 +435,7 @@ def __init__( azure_endpoint: str | None = None, azure_deployment: str | None = None, api_version: str | None = None, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, azure_ad_token: str | None = None, azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, organization: str | None = None, @@ -539,7 +539,7 @@ def __init__( def copy( self, *, - api_key: str | None = None, + api_key: str | Callable[[], Awaitable[str]] | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt "api-version": self._api_version, "deployment": self._azure_deployment or model, } - if self.api_key != "": + if self.api_key and self.api_key != "": auth_headers = {"api-key": self.api_key} else: token = await self._get_azure_ad_token() diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 7b99c7f6c4..4fa35963b6 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + await self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) diff --git a/src/openai/resources/realtime/realtime.py b/src/openai/resources/realtime/realtime.py index ebdfce86e3..2f5adf6548 100644 --- a/src/openai/resources/realtime/realtime.py +++ b/src/openai/resources/realtime/realtime.py @@ -326,6 +326,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + await self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -507,6 +508,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + self.__client._refresh_api_key() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) diff --git a/tests/test_client.py b/tests/test_client.py index ccda50a7f0..e5300e55d7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, cast +from typing import Any, Union, Protocol, cast from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -41,6 +41,10 @@ api_key = "My API Key" +class MockRequestCall(Protocol): + request: httpx.Request + + def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) url = httpx.URL(request.url) @@ -337,7 +341,9 @@ def test_default_headers_option(self) -> None: def test_validate_headers(self) -> None: client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) + assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -939,6 +945,62 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + def test_api_key_before_after_refresh_provider(self) -> None: + client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token") + + assert client.api_key == "" + assert "Authorization" not in client.auth_headers + + client._refresh_api_key() + + assert client.api_key == "test_bearer_token" + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + def test_api_key_before_after_refresh_str(self) -> None: + client = OpenAI(base_url=base_url, api_key="test_api_key") + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + client._refresh_api_key() + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.respx() + def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = OpenAI(base_url=base_url, api_key=token_provider) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_copy_auth(self) -> None: + client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy( + api_key=lambda: "test_bearer_token_2" + ) + client._refresh_api_key() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + class TestAsyncOpenAI: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1220,9 +1282,10 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: + async def test_validate_headers(self) -> None: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -1887,3 +1950,70 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + @pytest.mark.asyncio + async def test_api_key_before_after_refresh_provider(self) -> None: + async def mock_api_key_provider(): + return "test_bearer_token" + + client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider) + + assert client.api_key == "" + assert "Authorization" not in client.auth_headers + + await client._refresh_api_key() + + assert client.api_key == "test_bearer_token" + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + @pytest.mark.asyncio + async def test_api_key_before_after_refresh_str(self) -> None: + client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + await client._refresh_api_key() + + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.asyncio + @pytest.mark.respx() + async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncOpenAI(base_url=base_url, api_key=token_provider) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + @pytest.mark.asyncio + async def test_copy_auth(self) -> None: + async def token_provider_1() -> str: + return "test_bearer_token_1" + + async def token_provider_2() -> str: + return "test_bearer_token_2" + + client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2) + await client._refresh_api_key() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}