From 4146e29d98cec7ebc602739a18c18500eda04ced Mon Sep 17 00:00:00 2001 From: kristapratico Date: Fri, 15 Aug 2025 18:19:29 +0000 Subject: [PATCH 1/5] add auth parameter and AzureAuth --- src/openai/__init__.py | 21 +++- src/openai/_client.py | 79 ++++++++++--- src/openai/lib/azure.py | 111 +++++++++++++++++- .../resources/beta/realtime/realtime.py | 2 + 4 files changed, 191 insertions(+), 22 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 226fed9554..5b7ddda7f6 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -87,7 +87,12 @@ from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool from .version import VERSION as VERSION -from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI +from .lib.azure import ( + AzureAuth as AzureAuth, + AzureOpenAI as AzureOpenAI, + AsyncAzureAuth as AsyncAzureAuth, + AsyncAzureOpenAI as AsyncAzureOpenAI, +) from .lib._old_api import * from .lib.streaming import ( AssistantEventHandler as AssistantEventHandler, @@ -119,6 +124,8 @@ api_key: str | None = None +auth: AzureAuth | None = None + organization: str | None = None project: str | None = None @@ -165,6 +172,17 @@ def api_key(self, value: str | None) -> None: # type: ignore api_key = value + @property # type: ignore + @override + def auth(self) -> AzureAuth | None: + return auth + + @auth.setter # type: ignore + def auth(self, value: AzureAuth | None) -> None: # type: ignore + global auth + + auth = value + @property # type: ignore @override def organization(self) -> str | None: @@ -348,6 +366,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] _client = _ModuleClient( api_key=api_key, + auth=auth, organization=organization, project=project, webhook_secret=webhook_secret, diff --git a/src/openai/_client.py b/src/openai/_client.py index ed9b46f4b0..a547783eed 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -25,6 +25,7 @@ get_async_library, ) from ._compat import cached_property +from ._models import FinalRequestOptions from ._version import __version__ from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import OpenAIError, APIStatusError @@ -35,6 +36,7 @@ ) if TYPE_CHECKING: + from .lib.azure import AzureAuth, AsyncAzureAuth from .resources import ( beta, chat, @@ -93,6 +95,7 @@ def __init__( self, *, api_key: str | None = None, + auth: AzureAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -124,13 +127,16 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ + if api_key and auth: + raise ValueError("The `api_key` and `auth` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None: + if api_key is None and auth is None: raise OpenAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or auth client option must be set either by passing api_key or auth to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + self.auth = auth + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -163,6 +169,7 @@ def __init__( ) self._default_stream_cls = Stream + self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> Completions: @@ -279,14 +286,25 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + def refresh_auth_headers(self) -> None: + secret = self.auth.get_token() if self.auth else self.api_key + if not secret: + # if secret is an empty string, encoding the header will fail + # so we set it to an empty dict + # this is to avoid sending an invalid Authorization header + self._auth_headers = {} + else: + self._auth_headers = {"Authorization": f"Bearer {secret}"} + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + self.refresh_auth_headers() + return super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if not api_key: - # if the api key is an empty string, encoding the header will fail - return {} - return {"Authorization": f"Bearer {api_key}"} + return self._auth_headers @property @override @@ -303,6 +321,7 @@ def copy( self, *, api_key: str | None = None, + auth: AzureAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -338,6 +357,10 @@ def copy( elif set_default_query is not None: params = set_default_query + auth = auth or self.auth + if auth is not None: + _extra_kwargs = {**_extra_kwargs, "auth": auth} + http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, @@ -412,6 +435,7 @@ def __init__( self, *, api_key: str | None = None, + auth: AsyncAzureAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -443,13 +467,16 @@ def __init__( - `project` from `OPENAI_PROJECT_ID` - `webhook_secret` from `OPENAI_WEBHOOK_SECRET` """ + if api_key and auth: + raise ValueError("The `api_key` and `auth` arguments are mutually exclusive") if api_key is None: api_key = os.environ.get("OPENAI_API_KEY") - if api_key is None: + if api_key is None and auth is None: raise OpenAIError( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + "The api_key or auth client option must be set either by passing api_key or auth to the client or by setting the OPENAI_API_KEY environment variable" ) - self.api_key = api_key + self.auth = auth + self.api_key = api_key or "" if organization is None: organization = os.environ.get("OPENAI_ORG_ID") @@ -482,6 +509,7 @@ def __init__( ) self._default_stream_cls = AsyncStream + self._auth_headers: dict[str, str] = {} @cached_property def completions(self) -> AsyncCompletions: @@ -598,14 +626,28 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse: def qs(self) -> Querystring: return Querystring(array_format="brackets") + async def refresh_auth_headers(self) -> None: + if self.auth: + secret = await self.auth.get_token() + else: + secret = self.api_key + if not secret: + # if the secret is an empty string, encoding the header will fail + # so we set it to an empty dict + # this is to avoid sending an invalid Authorization header + self._auth_headers = {} + else: + self._auth_headers = {"Authorization": f"Bearer {secret}"} + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + await self.refresh_auth_headers() + return await super()._prepare_options(options) + @property @override def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - if not api_key: - # if the api key is an empty string, encoding the header will fail - return {} - return {"Authorization": f"Bearer {api_key}"} + return self._auth_headers @property @override @@ -622,6 +664,7 @@ def copy( self, *, api_key: str | None = None, + auth: AsyncAzureAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -657,6 +700,10 @@ def copy( elif set_default_query is not None: params = set_default_query + auth = auth or self.auth + if auth is not None: + _extra_kwargs = {**_extra_kwargs, "auth": auth} + http_client = http_client or self._client return self.__class__( api_key=api_key or self.api_key, diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index a994e4256c..260dabc4f8 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -2,7 +2,7 @@ import os import inspect -from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload +from typing import TYPE_CHECKING, Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload from typing_extensions import Self, override import httpx @@ -42,6 +42,107 @@ API_KEY_SENTINEL = "".join(["<", "missing API key", ">"]) +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential + + +AzureTokenProvider = Callable[[], str] +AsyncAzureTokenProvider = Callable[[], Awaitable[str]] + + +class AzureAuth: + @overload + def __init__(self, *, token_provider: AzureTokenProvider) -> None: ... + + @overload + def __init__( + self, + *, + credential: TokenCredential, + scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], + ) -> None: ... + + def __init__( + self, + *, + token_provider: AzureTokenProvider | None = None, + credential: TokenCredential | None = None, + scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], + ) -> None: + if token_provider is not None and credential is not None: + raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.") + if token_provider is None and credential is None: + raise ValueError("One of `token_provider` or `credential` must be provided to AzureAuth.") + self.token_provider = token_provider + self.credential = credential + self.scopes = scopes + + def get_token(self) -> str: + if self.token_provider is not None: + token = self.token_provider() + return token + + if self.credential is not None: + try: + from azure.identity import get_bearer_token_provider + except ImportError as err: + raise ImportError( + "azure-identity library is not installed. Please install it to use AzureAuth." + ) from err + token_provider = get_bearer_token_provider(self.credential, *self.scopes) + token = token_provider() + return token + + raise ValueError("Unexpected values provided to AzureAuth. Unable to get token.") + + +class AsyncAzureAuth: + @overload + def __init__(self, *, token_provider: AsyncAzureTokenProvider) -> None: ... + + @overload + def __init__( + self, + *, + credential: AsyncTokenCredential, + scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], + ) -> None: ... + + def __init__( + self, + *, + token_provider: AsyncAzureTokenProvider | None = None, + credential: AsyncTokenCredential | None = None, + scopes: list[str] = ["https://cognitiveservices.azure.com/.default"], + ) -> None: + if token_provider is not None and credential is not None: + raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.") + if token_provider is None and credential is None: + raise ValueError("One of `token_provider` or `credential` must be provided to AsyncAzureAuth.") + self.token_provider = token_provider + self.credential = credential + self.scopes = scopes + + async def get_token(self) -> str: + if self.token_provider is not None: + token = await self.token_provider() + return token + + if self.credential is not None: + try: + from azure.identity.aio import get_bearer_token_provider + except ImportError as err: + raise ImportError( + "azure-identity library is not installed. Please install it to use AsyncAzureAuth." + ) from err + token_provider = get_bearer_token_provider(self.credential, *self.scopes) + token = await token_provider() + return token + + raise ValueError("Unexpected values provided to AsyncAzureAuth. Unable to get token.") + + class MutuallyExclusiveAuthError(OpenAIError): def __init__(self) -> None: super().__init__( @@ -255,7 +356,7 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( + def copy( # type: ignore self, *, api_key: str | None = None, @@ -301,7 +402,7 @@ def copy( }, ) - with_options = copy + with_options = copy # type: ignore def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: @@ -536,7 +637,7 @@ def __init__( self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None @override - def copy( + def copy( # type: ignore self, *, api_key: str | None = None, @@ -582,7 +683,7 @@ def copy( }, ) - with_options = copy + with_options = copy # type: ignore async def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: diff --git a/src/openai/resources/beta/realtime/realtime.py b/src/openai/resources/beta/realtime/realtime.py index 8e1b558cf3..beff8eb582 100644 --- a/src/openai/resources/beta/realtime/realtime.py +++ b/src/openai/resources/beta/realtime/realtime.py @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + await self.__client.refresh_auth_headers() auth_headers = self.__client.auth_headers if is_async_azure_client(self.__client): url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query) @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc extra_query = self.__extra_query + self.__client.refresh_auth_headers() auth_headers = self.__client.auth_headers if is_azure_client(self.__client): url, auth_headers = self.__client._configure_realtime(self.__model, extra_query) From 2847fd66c2c27d3bb0f3502ce285a3ae8df4af1a Mon Sep 17 00:00:00 2001 From: kristapratico Date: Fri, 15 Aug 2025 21:26:37 +0000 Subject: [PATCH 2/5] add tests --- tests/lib/test_azure.py | 126 +++++++++++++++++++++++++++++- tests/test_client.py | 150 ++++++++++++++++++++++++++++++++++-- tests/test_module_client.py | 35 ++++++++- 3 files changed, 304 insertions(+), 7 deletions(-) diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 52c24eba27..ba4a2e1e44 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -2,6 +2,7 @@ import logging from typing import Union, cast +from unittest.mock import AsyncMock, MagicMock, patch from typing_extensions import Literal, Protocol import httpx @@ -10,10 +11,11 @@ from openai._utils import SensitiveHeadersFilter, is_dict from openai._models import FinalRequestOptions -from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI +from openai.lib.azure import AzureAuth, AzureOpenAI, AsyncAzureAuth, AsyncAzureOpenAI Client = Union[AzureOpenAI, AsyncAzureOpenAI] +mock_credential = MagicMock() sync_client = AzureOpenAI( api_version="2023-07-01", @@ -802,3 +804,125 @@ def test_client_sets_base_url(client: Client) -> None: ) ) assert req.url == "https://example-resource.azure.openai.com/openai/models?api-version=2024-02-01" + + +class TestAzureAuth: + """Test cases for the AzureAuth class.""" + + def test_init_with_token_provider(self) -> None: + def token_provider() -> str: + return "test-token-123" + + auth = AzureAuth(token_provider=token_provider) + assert auth.token_provider is token_provider + assert auth.credential is None + assert auth.scopes == ["https://cognitiveservices.azure.com/.default"] + + def test_init_with_credential(self) -> None: + auth = AzureAuth(credential=mock_credential) + assert auth.credential is mock_credential + assert auth.token_provider is None + assert auth.scopes == ["https://cognitiveservices.azure.com/.default"] + + def test_init_with_custom_scopes(self) -> None: + custom_scopes = ["https://custom.scope/.default"] + + auth = AzureAuth(credential=mock_credential, scopes=custom_scopes) + assert auth.scopes == custom_scopes + + def test_init_mutually_exclusive_raises_error(self) -> None: + def token_provider() -> str: + return "test-token-123" + + with pytest.raises(ValueError, match="mutually exclusive"): + AzureAuth(token_provider=token_provider, credential=mock_credential) # type: ignore[misc] + + def test_init_with_no_params_raises_error(self) -> None: + with pytest.raises(ValueError, match="One of `token_provider` or `credential` must be provided"): + AzureAuth() # type: ignore[misc] + + def test_get_token_with_token_provider(self) -> None: + expected_token = "test-token-456" + + def token_provider() -> str: + return expected_token + + auth = AzureAuth(token_provider=token_provider) + token = auth.get_token() + assert token == expected_token + + def test_get_token_with_credential(self) -> None: + auth = AzureAuth(credential=mock_credential) + + with patch("azure.identity.get_bearer_token_provider") as mock_provider: + mock_token_provider = MagicMock() + mock_token_provider.return_value = "azure-token-789" + mock_provider.return_value = mock_token_provider + + token = auth.get_token() + + assert token == "azure-token-789" + mock_provider.assert_called_once_with(mock_credential, *auth.scopes) + mock_token_provider.assert_called_once() + + +class TestAsyncAzureAuth: + """Test cases for the AsyncAzureAuth class.""" + + def test_init_with_token_provider(self) -> None: + async def async_token_provider() -> str: + return "async-test-token-123" + + auth = AsyncAzureAuth(token_provider=async_token_provider) + assert auth.token_provider is async_token_provider + assert auth.credential is None + assert auth.scopes == ["https://cognitiveservices.azure.com/.default"] + + def test_init_with_credential(self) -> None: + auth = AsyncAzureAuth(credential=mock_credential) + assert auth.credential is mock_credential + assert auth.token_provider is None + assert auth.scopes == ["https://cognitiveservices.azure.com/.default"] + + def test_init_with_custom_scopes(self) -> None: + custom_scopes = ["https://custom.scope/.default"] + + auth = AsyncAzureAuth(credential=mock_credential, scopes=custom_scopes) + assert auth.scopes == custom_scopes + + def test_init_mutually_exclusive_raises_error(self) -> None: + async def async_token_provider() -> str: + return "async-test-token-123" + + with pytest.raises(ValueError, match="mutually exclusive"): + AsyncAzureAuth(token_provider=async_token_provider, credential=mock_credential) # type: ignore[misc] + + def test_init_with_no_params_raises_error(self) -> None: + with pytest.raises(ValueError, match="One of `token_provider` or `credential` must be provided"): + AsyncAzureAuth() # type: ignore[misc] + + @pytest.mark.asyncio + async def test_get_token_with_token_provider(self) -> None: + expected_token = "async-test-token-456" + + async def async_token_provider() -> str: + return expected_token + + auth = AsyncAzureAuth(token_provider=async_token_provider) + token = await auth.get_token() + assert token == expected_token + + @pytest.mark.asyncio + async def test_get_token_with_credential(self) -> None: + auth = AsyncAzureAuth(credential=mock_credential) + + with patch("azure.identity.aio.get_bearer_token_provider") as mock_provider: + mock_token_provider = AsyncMock() + mock_token_provider.return_value = "async-azure-token-789" + mock_provider.return_value = mock_token_provider + + token = await auth.get_token() + + assert token == "async-azure-token-789" + mock_provider.assert_called_once_with(mock_credential, *auth.scopes) + mock_token_provider.assert_awaited_once() diff --git a/tests/test_client.py b/tests/test_client.py index ccda50a7f0..ecaa1fd9fb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -11,7 +11,7 @@ import inspect import subprocess import tracemalloc -from typing import Any, Union, cast +from typing import Any, Union, Protocol, cast from textwrap import dedent from unittest import mock from typing_extensions import Literal @@ -21,7 +21,7 @@ from respx import MockRouter from pydantic import ValidationError -from openai import OpenAI, AsyncOpenAI, APIResponseValidationError +from openai import OpenAI, AzureAuth, AsyncOpenAI, AsyncAzureAuth, APIResponseValidationError from openai._types import Omit from openai._models import BaseModel, FinalRequestOptions from openai._streaming import Stream, AsyncStream @@ -41,6 +41,10 @@ api_key = "My API Key" +class MockRequestCall(Protocol): + request: httpx.Request + + def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: request = client._build_request(FinalRequestOptions(method="get", url="/foo")) url = httpx.URL(request.url) @@ -337,7 +341,9 @@ def test_default_headers_option(self) -> None: def test_validate_headers(self) -> None: client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) + assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -939,6 +945,63 @@ def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + def test_refresh_auth_headers_token(self) -> None: + client = OpenAI(base_url=base_url, auth=AzureAuth(token_provider=lambda: "test_bearer_token")) + client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + def test_refresh_auth_headers_key(self) -> None: + client = OpenAI(base_url=base_url, api_key="test_api_key") + client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.respx() + def test_auth_provider_refresh(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = OpenAI(base_url=base_url, auth=AzureAuth(token_provider=token_provider)) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_auth_mutually_exclusive(self) -> None: + with pytest.raises(ValueError) as exc_info: + OpenAI(base_url=base_url, api_key=api_key, auth=AzureAuth(token_provider=lambda: "test_bearer_token")) + assert str(exc_info.value) == "The `api_key` and `auth` arguments are mutually exclusive" + + def test_copy_auth(self) -> None: + client = OpenAI(base_url=base_url, auth=AzureAuth(token_provider=lambda: "test_bearer_token_1")).copy( + auth=AzureAuth(token_provider=lambda: "test_bearer_token_2") + ) + client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + + def test_copy_auth_mutually_exclusive(self) -> None: + with pytest.raises(ValueError) as exc_info: + OpenAI(base_url=base_url, api_key=api_key).copy(auth=AzureAuth(token_provider=lambda: "test_bearer_token")) + assert str(exc_info.value) == "The `api_key` and `auth` arguments are mutually exclusive" + class TestAsyncOpenAI: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1220,9 +1283,10 @@ def test_default_headers_option(self) -> None: assert request.headers.get("x-foo") == "stainless" assert request.headers.get("x-stainless-lang") == "my-overriding-header" - def test_validate_headers(self) -> None: + async def test_validate_headers(self) -> None: client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) - request = client._build_request(FinalRequestOptions(method="get", url="/foo")) + options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo")) + request = client._build_request(options) assert request.headers.get("Authorization") == f"Bearer {api_key}" with pytest.raises(OpenAIError): @@ -1887,3 +1951,79 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None: assert exc_info.value.response.status_code == 302 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected" + + @pytest.mark.asyncio + async def test_refresh_auth_headers_token_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + client = AsyncOpenAI(base_url=base_url, auth=AsyncAzureAuth(token_provider=token_provider)) + await client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token" + + @pytest.mark.asyncio + async def test_refresh_auth_headers_key_async(self) -> None: + client = AsyncOpenAI(base_url=base_url, api_key="test_api_key") + await client.refresh_auth_headers() + assert client.auth_headers.get("Authorization") == "Bearer test_api_key" + + @pytest.mark.asyncio + @pytest.mark.respx() + async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None: + respx_mock.post(base_url + "/chat/completions").mock( + side_effect=[ + httpx.Response(500, json={"error": "server error"}), + httpx.Response(200, json={"foo": "bar"}), + ] + ) + + counter = 0 + + async def token_provider() -> str: + nonlocal counter + + counter += 1 + + if counter == 1: + return "first" + + return "second" + + client = AsyncOpenAI(base_url=base_url, auth=AsyncAzureAuth(token_provider=token_provider)) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert len(calls) == 2 + + assert calls[0].request.headers.get("Authorization") == "Bearer first" + assert calls[1].request.headers.get("Authorization") == "Bearer second" + + def test_auth_mutually_exclusive_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + with pytest.raises(ValueError) as exc_info: + AsyncOpenAI(base_url=base_url, api_key=api_key, auth=AsyncAzureAuth(token_provider=token_provider)) + assert str(exc_info.value) == "The `api_key` and `auth` arguments are mutually exclusive" + + @pytest.mark.asyncio + async def test_copy_auth(self) -> None: + async def token_provider_1() -> str: + return "test_bearer_token_1" + + async def token_provider_2() -> str: + return "test_bearer_token_2" + + client = AsyncOpenAI(base_url=base_url, auth=AsyncAzureAuth(token_provider=token_provider_1)).copy( + auth=AsyncAzureAuth(token_provider=token_provider_2) + ) + await client.refresh_auth_headers() + assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"} + + def test_copy_auth_mutually_exclusive_async(self) -> None: + async def token_provider() -> str: + return "test_bearer_token" + + with pytest.raises(ValueError) as exc_info: + AsyncOpenAI(base_url=base_url, api_key=api_key).copy(auth=AsyncAzureAuth(token_provider=token_provider)) + assert str(exc_info.value) == "The `api_key` and `auth` arguments are mutually exclusive" diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 9c9a1addab..446bf5a668 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -9,12 +9,13 @@ from httpx import URL import openai -from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES +from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, AzureAuth def reset_state() -> None: openai._reset_client() openai.api_key = None or "My API Key" + openai.auth = None openai.organization = None openai.project = None openai.webhook_secret = None @@ -97,6 +98,17 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client +def test_auth_provider_option() -> None: + assert openai.auth is None + assert openai.completions._client.auth is None + + openai.auth = AzureAuth(token_provider=lambda: "foo") + + assert openai.auth.get_token() == "foo" + assert openai.completions._client.auth + assert openai.completions._client.auth.get_token() == "foo" + + import contextlib from typing import Iterator @@ -123,6 +135,27 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" +def test_only_auth_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_type = None + openai.api_key = None + openai.auth = AzureAuth(token_provider=lambda: "foo") + + assert type(openai.completions._client).__name__ == "_ModuleClient" + + +def test_both_api_key_and_auth_provider_in_openai_api() -> None: + with fresh_env(): + openai.api_key = "example API key" + openai.auth = AzureAuth(token_provider=lambda: "foo") + + with pytest.raises( + ValueError, + match=r"The `api_key` and `auth` arguments are mutually exclusive", + ): + openai.completions._client # noqa: B018 + + def test_azure_api_key_env_without_api_version() -> None: with fresh_env(): openai.api_type = None From 5cc66ccc1bd20b85fbc2084fbff3418dc7fd891a Mon Sep 17 00:00:00 2001 From: kristapratico Date: Fri, 15 Aug 2025 21:26:46 +0000 Subject: [PATCH 3/5] add TokenAuth protocols --- src/openai/__init__.py | 8 ++++---- src/openai/_client.py | 11 ++++++----- src/openai/_types.py | 8 ++++++++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 5b7ddda7f6..cbad256e2a 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -7,7 +7,7 @@ from typing_extensions import override from . import types -from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes +from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, TokenAuth, Transport, ProxiesTypes from ._utils import file_from_path from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions from ._models import BaseModel @@ -124,7 +124,7 @@ api_key: str | None = None -auth: AzureAuth | None = None +auth: TokenAuth | None = None organization: str | None = None @@ -174,11 +174,11 @@ def api_key(self, value: str | None) -> None: # type: ignore @property # type: ignore @override - def auth(self) -> AzureAuth | None: + def auth(self) -> TokenAuth | None: return auth @auth.setter # type: ignore - def auth(self, value: AzureAuth | None) -> None: # type: ignore + def auth(self, value: TokenAuth | None) -> None: # type: ignore global auth auth = value diff --git a/src/openai/_client.py b/src/openai/_client.py index a547783eed..6ab196a51e 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -15,8 +15,10 @@ Omit, Timeout, NotGiven, + TokenAuth, Transport, ProxiesTypes, + AsyncTokenAuth, RequestOptions, ) from ._utils import ( @@ -36,7 +38,6 @@ ) if TYPE_CHECKING: - from .lib.azure import AzureAuth, AsyncAzureAuth from .resources import ( beta, chat, @@ -95,7 +96,7 @@ def __init__( self, *, api_key: str | None = None, - auth: AzureAuth | None = None, + auth: TokenAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -321,7 +322,7 @@ def copy( self, *, api_key: str | None = None, - auth: AzureAuth | None = None, + auth: TokenAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -435,7 +436,7 @@ def __init__( self, *, api_key: str | None = None, - auth: AsyncAzureAuth | None = None, + auth: AsyncTokenAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, @@ -664,7 +665,7 @@ def copy( self, *, api_key: str | None = None, - auth: AsyncAzureAuth | None = None, + auth: AsyncTokenAuth | None = None, organization: str | None = None, project: str | None = None, webhook_secret: str | None = None, diff --git a/src/openai/_types.py b/src/openai/_types.py index 5dae55f4a9..53ab5d21fb 100644 --- a/src/openai/_types.py +++ b/src/openai/_types.py @@ -219,3 +219,11 @@ class _GenericAlias(Protocol): class HttpxSendArgs(TypedDict, total=False): auth: httpx.Auth follow_redirects: bool + + +class TokenAuth(Protocol): + def get_token(self) -> str: ... + + +class AsyncTokenAuth(Protocol): + async def get_token(self) -> str: ... From 908e9899846ea309b3accee5f4c90e24bbc2aca5 Mon Sep 17 00:00:00 2001 From: kristapratico Date: Sat, 16 Aug 2025 00:03:51 +0000 Subject: [PATCH 4/5] feedback --- src/openai/lib/azure.py | 46 +++++++++++++++++++++-------------------- tests/lib/test_azure.py | 8 ++----- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 260dabc4f8..cd4dbb7b97 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -74,24 +74,25 @@ def __init__( raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.") if token_provider is None and credential is None: raise ValueError("One of `token_provider` or `credential` must be provided to AzureAuth.") - self.token_provider = token_provider - self.credential = credential - self.scopes = scopes - - def get_token(self) -> str: - if self.token_provider is not None: - token = self.token_provider() - return token + if isinstance(scopes, str): # type: ignore[unreachable] + raise TypeError("`scopes` must be a list of strings, not a single string.") - if self.credential is not None: + if credential is not None: try: from azure.identity import get_bearer_token_provider except ImportError as err: raise ImportError( "azure-identity library is not installed. Please install it to use AzureAuth." ) from err - token_provider = get_bearer_token_provider(self.credential, *self.scopes) - token = token_provider() + token_provider = get_bearer_token_provider(credential, *scopes) + + self.token_provider = token_provider + self.credential = credential + self.scopes = scopes + + def get_token(self) -> str: + if self.token_provider is not None: + token = self.token_provider() return token raise ValueError("Unexpected values provided to AzureAuth. Unable to get token.") @@ -120,24 +121,25 @@ def __init__( raise ValueError("The `token_provider` and `credential` arguments are mutually exclusive.") if token_provider is None and credential is None: raise ValueError("One of `token_provider` or `credential` must be provided to AsyncAzureAuth.") - self.token_provider = token_provider - self.credential = credential - self.scopes = scopes - - async def get_token(self) -> str: - if self.token_provider is not None: - token = await self.token_provider() - return token + if isinstance(scopes, str): # type: ignore[unreachable] + raise TypeError("`scopes` must be a list of strings, not a single string.") - if self.credential is not None: + if credential is not None: try: from azure.identity.aio import get_bearer_token_provider except ImportError as err: raise ImportError( "azure-identity library is not installed. Please install it to use AsyncAzureAuth." ) from err - token_provider = get_bearer_token_provider(self.credential, *self.scopes) - token = await token_provider() + token_provider = get_bearer_token_provider(credential, *scopes) + + self.token_provider = token_provider + self.credential = credential + self.scopes = scopes + + async def get_token(self) -> str: + if self.token_provider is not None: + token = await self.token_provider() return token raise ValueError("Unexpected values provided to AsyncAzureAuth. Unable to get token.") diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index ba4a2e1e44..9b9d401087 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -821,7 +821,6 @@ def token_provider() -> str: def test_init_with_credential(self) -> None: auth = AzureAuth(credential=mock_credential) assert auth.credential is mock_credential - assert auth.token_provider is None assert auth.scopes == ["https://cognitiveservices.azure.com/.default"] def test_init_with_custom_scopes(self) -> None: @@ -852,13 +851,12 @@ def token_provider() -> str: assert token == expected_token def test_get_token_with_credential(self) -> None: - auth = AzureAuth(credential=mock_credential) - with patch("azure.identity.get_bearer_token_provider") as mock_provider: mock_token_provider = MagicMock() mock_token_provider.return_value = "azure-token-789" mock_provider.return_value = mock_token_provider + auth = AzureAuth(credential=mock_credential) token = auth.get_token() assert token == "azure-token-789" @@ -881,7 +879,6 @@ async def async_token_provider() -> str: def test_init_with_credential(self) -> None: auth = AsyncAzureAuth(credential=mock_credential) assert auth.credential is mock_credential - assert auth.token_provider is None assert auth.scopes == ["https://cognitiveservices.azure.com/.default"] def test_init_with_custom_scopes(self) -> None: @@ -914,13 +911,12 @@ async def async_token_provider() -> str: @pytest.mark.asyncio async def test_get_token_with_credential(self) -> None: - auth = AsyncAzureAuth(credential=mock_credential) - with patch("azure.identity.aio.get_bearer_token_provider") as mock_provider: mock_token_provider = AsyncMock() mock_token_provider.return_value = "async-azure-token-789" mock_provider.return_value = mock_token_provider + auth = AsyncAzureAuth(credential=mock_credential) token = await auth.get_token() assert token == "async-azure-token-789" From ad1fe0ae8a2c76354eb39aa97da79b82cdd129de Mon Sep 17 00:00:00 2001 From: kristapratico Date: Sat, 16 Aug 2025 00:56:13 +0000 Subject: [PATCH 5/5] rename tests --- tests/test_client.py | 2 +- tests/test_module_client.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index ecaa1fd9fb..927ac966c7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -956,7 +956,7 @@ def test_refresh_auth_headers_key(self) -> None: assert client.auth_headers.get("Authorization") == "Bearer test_api_key" @pytest.mark.respx() - def test_auth_provider_refresh(self, respx_mock: MockRouter) -> None: + def test_bearer_token_refresh(self, respx_mock: MockRouter) -> None: respx_mock.post(base_url + "/chat/completions").mock( side_effect=[ httpx.Response(500, json={"error": "server error"}), diff --git a/tests/test_module_client.py b/tests/test_module_client.py index 446bf5a668..d0fd050151 100644 --- a/tests/test_module_client.py +++ b/tests/test_module_client.py @@ -98,7 +98,7 @@ def test_http_client_option() -> None: assert openai.completions._client._client is new_client -def test_auth_provider_option() -> None: +def test_auth_option() -> None: assert openai.auth is None assert openai.completions._client.auth is None @@ -135,7 +135,7 @@ def test_only_api_key_results_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" -def test_only_auth_provider_in_openai_api() -> None: +def test_only_auth_in_openai_api() -> None: with fresh_env(): openai.api_type = None openai.api_key = None @@ -144,7 +144,7 @@ def test_only_auth_provider_in_openai_api() -> None: assert type(openai.completions._client).__name__ == "_ModuleClient" -def test_both_api_key_and_auth_provider_in_openai_api() -> None: +def test_both_api_key_and_auth_in_openai_api() -> None: with fresh_env(): openai.api_key = "example API key" openai.auth = AzureAuth(token_provider=lambda: "foo")