Skip to content

Commit 6a522a4

Browse files
committed
copy changes from stainless
1 parent 51f1a84 commit 6a522a4

File tree

4 files changed

+216
-421
lines changed

4 files changed

+216
-421
lines changed

src/alpha/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33
from . import types
44
from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes
55
from ._utils import file_from_path
6-
from ._client import Alpha, Client, Stream, Timeout, Transport, AsyncAlpha, AsyncClient, AsyncStream, RequestOptions
6+
from ._client import (
7+
ENVIRONMENTS,
8+
Alpha,
9+
Client,
10+
Stream,
11+
Timeout,
12+
Transport,
13+
AsyncAlpha,
14+
AsyncClient,
15+
AsyncStream,
16+
RequestOptions,
17+
)
718
from ._models import BaseModel
819
from ._version import __title__, __version__
920
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
@@ -59,6 +70,7 @@
5970
"AsyncStream",
6071
"Alpha",
6172
"AsyncAlpha",
73+
"ENVIRONMENTS",
6274
"file_from_path",
6375
"BaseModel",
6476
"DEFAULT_TIMEOUT",

src/alpha/_client.py

Lines changed: 142 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import Any, Union, Mapping
7-
from typing_extensions import Self, override
6+
from typing import Any, Dict, Union, Mapping, cast
7+
from typing_extensions import Self, Literal, override
88

99
import httpx
1010

@@ -13,6 +13,7 @@
1313
from ._types import (
1414
NOT_GIVEN,
1515
Omit,
16+
Headers,
1617
Timeout,
1718
NotGiven,
1819
Transport,
@@ -26,7 +27,7 @@
2627
from ._version import __version__
2728
from .resources import health
2829
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
29-
from ._exceptions import AlphaError, APIStatusError
30+
from ._exceptions import APIStatusError
3031
from ._base_client import (
3132
DEFAULT_MAX_RETRIES,
3233
SyncAPIClient,
@@ -36,7 +37,22 @@
3637
from .resources.projects import projects
3738
from .resources.organizations import organizations
3839

39-
__all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "Alpha", "AsyncAlpha", "Client", "AsyncClient"]
40+
__all__ = [
41+
"ENVIRONMENTS",
42+
"Timeout",
43+
"Transport",
44+
"ProxiesTypes",
45+
"RequestOptions",
46+
"Alpha",
47+
"AsyncAlpha",
48+
"Client",
49+
"AsyncClient",
50+
]
51+
52+
ENVIRONMENTS: Dict[str, str] = {
53+
"production": "https://api-alpha-o3gxj3oajfu.cleanlab.ai/",
54+
"local": "http://localhost:8080",
55+
}
4056

4157

4258
class Alpha(SyncAPIClient):
@@ -48,17 +64,20 @@ class Alpha(SyncAPIClient):
4864
with_streaming_response: AlphaWithStreamedResponse
4965

5066
# client options
51-
bearer_token: str
52-
api_key: str
53-
access_key: str
67+
bearer_token: str | None
68+
api_key: str | None
69+
access_key: str | None
70+
71+
_environment: Literal["production", "local"] | NotGiven
5472

5573
def __init__(
5674
self,
5775
*,
5876
bearer_token: str | None = None,
5977
api_key: str | None = None,
6078
access_key: str | None = None,
61-
base_url: str | httpx.URL | None = None,
79+
environment: Literal["production", "local"] | NotGiven = NOT_GIVEN,
80+
base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN,
6281
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
6382
max_retries: int = DEFAULT_MAX_RETRIES,
6483
default_headers: Mapping[str, str] | None = None,
@@ -86,32 +105,41 @@ def __init__(
86105
"""
87106
if bearer_token is None:
88107
bearer_token = os.environ.get("BEARER_TOKEN")
89-
if bearer_token is None:
90-
raise AlphaError(
91-
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the BEARER_TOKEN environment variable"
92-
)
93108
self.bearer_token = bearer_token
94109

95110
if api_key is None:
96111
api_key = os.environ.get("X_API_KEY")
97-
if api_key is None:
98-
raise AlphaError(
99-
"The api_key client option must be set either by passing api_key to the client or by setting the X_API_KEY environment variable"
100-
)
101112
self.api_key = api_key
102113

103114
if access_key is None:
104115
access_key = os.environ.get("X_ACCESS_KEY")
105-
if access_key is None:
106-
raise AlphaError(
107-
"The access_key client option must be set either by passing access_key to the client or by setting the X_ACCESS_KEY environment variable"
108-
)
109116
self.access_key = access_key
110117

111-
if base_url is None:
112-
base_url = os.environ.get("ALPHA_BASE_URL")
113-
if base_url is None:
114-
base_url = f"https://localhost:8080/test-api"
118+
self._environment = environment
119+
120+
base_url_env = os.environ.get("ALPHA_BASE_URL")
121+
if is_given(base_url) and base_url is not None:
122+
# cast required because mypy doesn't understand the type narrowing
123+
base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast]
124+
elif is_given(environment):
125+
if base_url_env and base_url is not None:
126+
raise ValueError(
127+
"Ambiguous URL; The `ALPHA_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None",
128+
)
129+
130+
try:
131+
base_url = ENVIRONMENTS[environment]
132+
except KeyError as exc:
133+
raise ValueError(f"Unknown environment: {environment}") from exc
134+
elif base_url_env is not None:
135+
base_url = base_url_env
136+
else:
137+
self._environment = environment = "production"
138+
139+
try:
140+
base_url = ENVIRONMENTS[environment]
141+
except KeyError as exc:
142+
raise ValueError(f"Unknown environment: {environment}") from exc
115143

116144
super().__init__(
117145
version=__version__,
@@ -150,16 +178,22 @@ def auth_headers(self) -> dict[str, str]:
150178
@property
151179
def _http_bearer(self) -> dict[str, str]:
152180
bearer_token = self.bearer_token
181+
if bearer_token is None:
182+
return {}
153183
return {"Authorization": f"Bearer {bearer_token}"}
154184

155185
@property
156186
def _authenticated_api_key(self) -> dict[str, str]:
157187
api_key = self.api_key
188+
if api_key is None:
189+
return {}
158190
return {"X-API-Key": api_key}
159191

160192
@property
161193
def _public_access_key(self) -> dict[str, str]:
162194
access_key = self.access_key
195+
if access_key is None:
196+
return {}
163197
return {"X-Access-Key": access_key}
164198

165199
@property
@@ -171,12 +205,34 @@ def default_headers(self) -> dict[str, str | Omit]:
171205
**self._custom_headers,
172206
}
173207

208+
@override
209+
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
210+
if self.bearer_token and headers.get("Authorization"):
211+
return
212+
if isinstance(custom_headers.get("Authorization"), Omit):
213+
return
214+
215+
if self.api_key and headers.get("X-API-Key"):
216+
return
217+
if isinstance(custom_headers.get("X-API-Key"), Omit):
218+
return
219+
220+
if self.access_key and headers.get("X-Access-Key"):
221+
return
222+
if isinstance(custom_headers.get("X-Access-Key"), Omit):
223+
return
224+
225+
raise TypeError(
226+
'"Could not resolve authentication method. Expected one of bearer_token, api_key or access_key to be set. Or for one of the `Authorization`, `X-API-Key` or `X-Access-Key` headers to be explicitly omitted"'
227+
)
228+
174229
def copy(
175230
self,
176231
*,
177232
bearer_token: str | None = None,
178233
api_key: str | None = None,
179234
access_key: str | None = None,
235+
environment: Literal["production", "local"] | None = None,
180236
base_url: str | httpx.URL | None = None,
181237
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
182238
http_client: httpx.Client | None = None,
@@ -214,6 +270,7 @@ def copy(
214270
api_key=api_key or self.api_key,
215271
access_key=access_key or self.access_key,
216272
base_url=base_url or self.base_url,
273+
environment=environment or self._environment,
217274
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
218275
http_client=http_client,
219276
max_retries=max_retries if is_given(max_retries) else self.max_retries,
@@ -269,17 +326,20 @@ class AsyncAlpha(AsyncAPIClient):
269326
with_streaming_response: AsyncAlphaWithStreamedResponse
270327

271328
# client options
272-
bearer_token: str
273-
api_key: str
274-
access_key: str
329+
bearer_token: str | None
330+
api_key: str | None
331+
access_key: str | None
332+
333+
_environment: Literal["production", "local"] | NotGiven
275334

276335
def __init__(
277336
self,
278337
*,
279338
bearer_token: str | None = None,
280339
api_key: str | None = None,
281340
access_key: str | None = None,
282-
base_url: str | httpx.URL | None = None,
341+
environment: Literal["production", "local"] | NotGiven = NOT_GIVEN,
342+
base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN,
283343
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
284344
max_retries: int = DEFAULT_MAX_RETRIES,
285345
default_headers: Mapping[str, str] | None = None,
@@ -307,32 +367,41 @@ def __init__(
307367
"""
308368
if bearer_token is None:
309369
bearer_token = os.environ.get("BEARER_TOKEN")
310-
if bearer_token is None:
311-
raise AlphaError(
312-
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the BEARER_TOKEN environment variable"
313-
)
314370
self.bearer_token = bearer_token
315371

316372
if api_key is None:
317373
api_key = os.environ.get("X_API_KEY")
318-
if api_key is None:
319-
raise AlphaError(
320-
"The api_key client option must be set either by passing api_key to the client or by setting the X_API_KEY environment variable"
321-
)
322374
self.api_key = api_key
323375

324376
if access_key is None:
325377
access_key = os.environ.get("X_ACCESS_KEY")
326-
if access_key is None:
327-
raise AlphaError(
328-
"The access_key client option must be set either by passing access_key to the client or by setting the X_ACCESS_KEY environment variable"
329-
)
330378
self.access_key = access_key
331379

332-
if base_url is None:
333-
base_url = os.environ.get("ALPHA_BASE_URL")
334-
if base_url is None:
335-
base_url = f"https://localhost:8080/test-api"
380+
self._environment = environment
381+
382+
base_url_env = os.environ.get("ALPHA_BASE_URL")
383+
if is_given(base_url) and base_url is not None:
384+
# cast required because mypy doesn't understand the type narrowing
385+
base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast]
386+
elif is_given(environment):
387+
if base_url_env and base_url is not None:
388+
raise ValueError(
389+
"Ambiguous URL; The `ALPHA_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None",
390+
)
391+
392+
try:
393+
base_url = ENVIRONMENTS[environment]
394+
except KeyError as exc:
395+
raise ValueError(f"Unknown environment: {environment}") from exc
396+
elif base_url_env is not None:
397+
base_url = base_url_env
398+
else:
399+
self._environment = environment = "production"
400+
401+
try:
402+
base_url = ENVIRONMENTS[environment]
403+
except KeyError as exc:
404+
raise ValueError(f"Unknown environment: {environment}") from exc
336405

337406
super().__init__(
338407
version=__version__,
@@ -371,16 +440,22 @@ def auth_headers(self) -> dict[str, str]:
371440
@property
372441
def _http_bearer(self) -> dict[str, str]:
373442
bearer_token = self.bearer_token
443+
if bearer_token is None:
444+
return {}
374445
return {"Authorization": f"Bearer {bearer_token}"}
375446

376447
@property
377448
def _authenticated_api_key(self) -> dict[str, str]:
378449
api_key = self.api_key
450+
if api_key is None:
451+
return {}
379452
return {"X-API-Key": api_key}
380453

381454
@property
382455
def _public_access_key(self) -> dict[str, str]:
383456
access_key = self.access_key
457+
if access_key is None:
458+
return {}
384459
return {"X-Access-Key": access_key}
385460

386461
@property
@@ -392,12 +467,34 @@ def default_headers(self) -> dict[str, str | Omit]:
392467
**self._custom_headers,
393468
}
394469

470+
@override
471+
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
472+
if self.bearer_token and headers.get("Authorization"):
473+
return
474+
if isinstance(custom_headers.get("Authorization"), Omit):
475+
return
476+
477+
if self.api_key and headers.get("X-API-Key"):
478+
return
479+
if isinstance(custom_headers.get("X-API-Key"), Omit):
480+
return
481+
482+
if self.access_key and headers.get("X-Access-Key"):
483+
return
484+
if isinstance(custom_headers.get("X-Access-Key"), Omit):
485+
return
486+
487+
raise TypeError(
488+
'"Could not resolve authentication method. Expected one of bearer_token, api_key or access_key to be set. Or for one of the `Authorization`, `X-API-Key` or `X-Access-Key` headers to be explicitly omitted"'
489+
)
490+
395491
def copy(
396492
self,
397493
*,
398494
bearer_token: str | None = None,
399495
api_key: str | None = None,
400496
access_key: str | None = None,
497+
environment: Literal["production", "local"] | None = None,
401498
base_url: str | httpx.URL | None = None,
402499
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
403500
http_client: httpx.AsyncClient | None = None,
@@ -435,6 +532,7 @@ def copy(
435532
api_key=api_key or self.api_key,
436533
access_key=access_key or self.access_key,
437534
base_url=base_url or self.base_url,
535+
environment=environment or self._environment,
438536
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
439537
http_client=http_client,
440538
max_retries=max_retries if is_given(max_retries) else self.max_retries,

0 commit comments

Comments
 (0)