Skip to content

Commit d4d159a

Browse files
committed
review feedback + remove callable api_key from module client
1 parent 60dba27 commit d4d159a

File tree

5 files changed

+19
-69
lines changed

5 files changed

+19
-69
lines changed

src/openai/__init__.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@
117117

118118
from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
119119

120-
api_key: str | _t.Callable[[], str] | None = None
120+
api_key: str | None = None
121121

122122
organization: str | None = None
123123

@@ -156,27 +156,15 @@ class _ModuleClient(OpenAI):
156156

157157
@property # type: ignore
158158
@override
159-
def api_key(self) -> str | _t.Callable[[], str] | None:
160-
return api_key() if callable(api_key) else api_key
159+
def api_key(self) -> str | None:
160+
return api_key
161161

162162
@api_key.setter # type: ignore
163-
def api_key(self, value: str | _t.Callable[[], str] | None) -> None: # type: ignore
163+
def api_key(self, value: str | None) -> None: # type: ignore
164164
global api_key
165165
api_key = value
166166

167-
@property
168-
def _api_key_provider(self) -> _t.Callable[[], str] | None: # type: ignore
169-
return None
170-
171-
@_api_key_provider.setter
172-
def _api_key_provider(self, value: _t.Callable[[], str] | None) -> None: # type: ignore
173-
global api_key
174-
# Yes, setting the api_key is intentional. The module level client accepts callables
175-
# for the module level api_key and will call it to retrieve the value
176-
# if it is a callable.
177-
api_key = value
178-
179-
@property
167+
@property # type: ignore
180168
@override
181169
def organization(self) -> str | None:
182170
return organization

src/openai/_client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88

99
import httpx
1010

11-
from openai._models import FinalRequestOptions
12-
1311
from . import _exceptions
1412
from ._qs import Querystring
1513
from ._types import (
@@ -27,6 +25,7 @@
2725
get_async_library,
2826
)
2927
from ._compat import cached_property
28+
from ._models import FinalRequestOptions
3029
from ._version import __version__
3130
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
3231
from ._exceptions import OpenAIError, APIStatusError
@@ -138,7 +137,7 @@ def __init__(
138137
self.api_key = ""
139138
self._api_key_provider: Callable[[], str] | None = api_key
140139
else:
141-
self.api_key = api_key or ""
140+
self.api_key = api_key
142141
self._api_key_provider = None
143142

144143
if organization is None:
@@ -477,7 +476,7 @@ def __init__(
477476
self.api_key = ""
478477
self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
479478
else:
480-
self.api_key = api_key or ""
479+
self.api_key = api_key
481480
self._api_key_provider = None
482481

483482
if organization is None:

src/openai/lib/azure.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
azure_endpoint: str,
9595
azure_deployment: str | None = None,
9696
api_version: str | None = None,
97-
api_key: str | None = None,
97+
api_key: str | Callable[[], str] | None = None,
9898
azure_ad_token: str | None = None,
9999
azure_ad_token_provider: AzureADTokenProvider | None = None,
100100
organization: str | None = None,
@@ -114,7 +114,7 @@ def __init__(
114114
*,
115115
azure_deployment: str | None = None,
116116
api_version: str | None = None,
117-
api_key: str | None = None,
117+
api_key: str | Callable[[], str] | None = None,
118118
azure_ad_token: str | None = None,
119119
azure_ad_token_provider: AzureADTokenProvider | None = None,
120120
organization: str | None = None,
@@ -134,7 +134,7 @@ def __init__(
134134
*,
135135
base_url: str,
136136
api_version: str | None = None,
137-
api_key: str | None = None,
137+
api_key: str | Callable[[], str] | None = None,
138138
azure_ad_token: str | None = None,
139139
azure_ad_token_provider: AzureADTokenProvider | None = None,
140140
organization: str | None = None,
@@ -154,7 +154,7 @@ def __init__(
154154
api_version: str | None = None,
155155
azure_endpoint: str | None = None,
156156
azure_deployment: str | None = None,
157-
api_key: str | None = None,
157+
api_key: str | Callable[[], str] | None = None,
158158
azure_ad_token: str | None = None,
159159
azure_ad_token_provider: AzureADTokenProvider | None = None,
160160
organization: str | None = None,
@@ -345,7 +345,7 @@ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL
345345
"api-version": self._api_version,
346346
"deployment": self._azure_deployment or model,
347347
}
348-
if self.api_key != "<missing API key>":
348+
if self.api_key and self.api_key != "<missing API key>":
349349
auth_headers = {"api-key": self.api_key}
350350
else:
351351
token = self._get_azure_ad_token()
@@ -372,7 +372,7 @@ def __init__(
372372
azure_endpoint: str,
373373
azure_deployment: str | None = None,
374374
api_version: str | None = None,
375-
api_key: str | None = None,
375+
api_key: str | Callable[[], Awaitable[str]] | None = None,
376376
azure_ad_token: str | None = None,
377377
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
378378
organization: str | None = None,
@@ -393,7 +393,7 @@ def __init__(
393393
*,
394394
azure_deployment: str | None = None,
395395
api_version: str | None = None,
396-
api_key: str | None = None,
396+
api_key: str | Callable[[], Awaitable[str]] | None = None,
397397
azure_ad_token: str | None = None,
398398
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
399399
organization: str | None = None,
@@ -414,7 +414,7 @@ def __init__(
414414
*,
415415
base_url: str,
416416
api_version: str | None = None,
417-
api_key: str | None = None,
417+
api_key: str | Callable[[], Awaitable[str]] | None = None,
418418
azure_ad_token: str | None = None,
419419
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
420420
organization: str | None = None,

tests/test_client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -949,14 +949,13 @@ def test_api_key_before_after_refresh_provider(self) -> None:
949949
client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
950950

951951
assert client.api_key == ""
952-
assert 'Authorization' not in client.auth_headers
952+
assert "Authorization" not in client.auth_headers
953953

954954
client._refresh_api_key()
955955

956956
assert client.api_key == "test_bearer_token"
957957
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
958958

959-
960959
def test_api_key_before_after_refresh_str(self) -> None:
961960
client = OpenAI(base_url=base_url, api_key="test_api_key")
962961

@@ -1956,18 +1955,17 @@ async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
19561955
async def test_api_key_before_after_refresh_provider(self) -> None:
19571956
async def mock_api_key_provider():
19581957
return "test_bearer_token"
1959-
1958+
19601959
client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
19611960

19621961
assert client.api_key == ""
1963-
assert 'Authorization' not in client.auth_headers
1962+
assert "Authorization" not in client.auth_headers
19641963

19651964
await client._refresh_api_key()
19661965

19671966
assert client.api_key == "test_bearer_token"
19681967
assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
19691968

1970-
19711969
@pytest.mark.asyncio
19721970
async def test_api_key_before_after_refresh_str(self) -> None:
19731971
client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")

tests/test_module_client.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,6 @@ def test_http_client_option() -> None:
9797
assert openai.completions._client._client is new_client
9898

9999

100-
def test_api_key_callable() -> None:
101-
openai.api_key = lambda: "1"
102-
assert openai.completions._client.api_key == "1"
103-
104-
def test_api_key_overridable() -> None:
105-
openai.api_key = lambda: "1"
106-
assert openai.completions._client.api_key == "1"
107-
assert openai.completions._client._api_key_provider is None
108-
109-
openai.api_key = "2"
110-
assert openai.completions._client.api_key == "2"
111-
assert openai.completions._client._api_key_provider is None
112-
113-
openai.api_key = lambda: "3"
114-
assert openai.completions._client.api_key == "3"
115-
assert openai.completions._client._api_key_provider is None
116-
117100
import contextlib
118101
from typing import Iterator
119102

@@ -140,24 +123,6 @@ def test_only_api_key_results_in_openai_api() -> None:
140123
assert type(openai.completions._client).__name__ == "_ModuleClient"
141124

142125

143-
def test_only_api_key_in_openai_api() -> None:
144-
with fresh_env():
145-
openai.api_type = None
146-
openai.api_key = lambda: "example bearer token"
147-
148-
assert type(openai.completions._client).__name__ == "_ModuleClient"
149-
150-
151-
def test_both_api_key_and_api_key_provider_in_openai_api() -> None:
152-
with fresh_env():
153-
openai.api_key = lambda: "example bearer token"
154-
155-
assert openai.api_key() == "example bearer token"
156-
157-
openai.api_key = "example API key"
158-
assert openai.api_key == "example API key"
159-
160-
161126
def test_azure_api_key_env_without_api_version() -> None:
162127
with fresh_env():
163128
openai.api_type = None

0 commit comments

Comments
 (0)