33from __future__ import annotations
44
55import os
6- from typing import Any , Union , Mapping
7- from typing_extensions import Self , override
6+ from typing import Any , Dict , Union , Mapping , cast
7+ from typing_extensions import Self , Literal , override
88
99import httpx
1010
1313from ._types import (
1414 NOT_GIVEN ,
1515 Omit ,
16+ Headers ,
1617 Timeout ,
1718 NotGiven ,
1819 Transport ,
2627from ._version import __version__
2728from .resources import health
2829from ._streaming import Stream as Stream , AsyncStream as AsyncStream
29- from ._exceptions import AlphaError , APIStatusError
30+ from ._exceptions import APIStatusError
3031from ._base_client import (
3132 DEFAULT_MAX_RETRIES ,
3233 SyncAPIClient ,
3637from .resources .projects import projects
3738from .resources .organizations import organizations
3839
39- __all__ = ["Timeout" , "Transport" , "ProxiesTypes" , "RequestOptions" , "Alpha" , "AsyncAlpha" , "Client" , "AsyncClient" ]
40+ __all__ = [
41+ "ENVIRONMENTS" ,
42+ "Timeout" ,
43+ "Transport" ,
44+ "ProxiesTypes" ,
45+ "RequestOptions" ,
46+ "Alpha" ,
47+ "AsyncAlpha" ,
48+ "Client" ,
49+ "AsyncClient" ,
50+ ]
51+
52+ ENVIRONMENTS : Dict [str , str ] = {
53+ "production" : "https://api-alpha-o3gxj3oajfu.cleanlab.ai/" ,
54+ "local" : "http://localhost:8080" ,
55+ }
4056
4157
4258class Alpha (SyncAPIClient ):
@@ -48,17 +64,20 @@ class Alpha(SyncAPIClient):
4864 with_streaming_response : AlphaWithStreamedResponse
4965
5066 # client options
51- bearer_token : str
52- api_key : str
53- access_key : str
67+ bearer_token : str | None
68+ api_key : str | None
69+ access_key : str | None
70+
71+ _environment : Literal ["production" , "local" ] | NotGiven
5472
5573 def __init__ (
5674 self ,
5775 * ,
5876 bearer_token : str | None = None ,
5977 api_key : str | None = None ,
6078 access_key : str | None = None ,
61- base_url : str | httpx .URL | None = None ,
79+ environment : Literal ["production" , "local" ] | NotGiven = NOT_GIVEN ,
80+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
6281 timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
6382 max_retries : int = DEFAULT_MAX_RETRIES ,
6483 default_headers : Mapping [str , str ] | None = None ,
@@ -86,32 +105,41 @@ def __init__(
86105 """
87106 if bearer_token is None :
88107 bearer_token = os .environ .get ("BEARER_TOKEN" )
89- if bearer_token is None :
90- raise AlphaError (
91- "The bearer_token client option must be set either by passing bearer_token to the client or by setting the BEARER_TOKEN environment variable"
92- )
93108 self .bearer_token = bearer_token
94109
95110 if api_key is None :
96111 api_key = os .environ .get ("X_API_KEY" )
97- if api_key is None :
98- raise AlphaError (
99- "The api_key client option must be set either by passing api_key to the client or by setting the X_API_KEY environment variable"
100- )
101112 self .api_key = api_key
102113
103114 if access_key is None :
104115 access_key = os .environ .get ("X_ACCESS_KEY" )
105- if access_key is None :
106- raise AlphaError (
107- "The access_key client option must be set either by passing access_key to the client or by setting the X_ACCESS_KEY environment variable"
108- )
109116 self .access_key = access_key
110117
111- if base_url is None :
112- base_url = os .environ .get ("ALPHA_BASE_URL" )
113- if base_url is None :
114- base_url = f"https://localhost:8080/test-api"
118+ self ._environment = environment
119+
120+ base_url_env = os .environ .get ("ALPHA_BASE_URL" )
121+ if is_given (base_url ) and base_url is not None :
122+ # cast required because mypy doesn't understand the type narrowing
123+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
124+ elif is_given (environment ):
125+ if base_url_env and base_url is not None :
126+ raise ValueError (
127+ "Ambiguous URL; The `ALPHA_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
128+ )
129+
130+ try :
131+ base_url = ENVIRONMENTS [environment ]
132+ except KeyError as exc :
133+ raise ValueError (f"Unknown environment: { environment } " ) from exc
134+ elif base_url_env is not None :
135+ base_url = base_url_env
136+ else :
137+ self ._environment = environment = "production"
138+
139+ try :
140+ base_url = ENVIRONMENTS [environment ]
141+ except KeyError as exc :
142+ raise ValueError (f"Unknown environment: { environment } " ) from exc
115143
116144 super ().__init__ (
117145 version = __version__ ,
@@ -150,16 +178,22 @@ def auth_headers(self) -> dict[str, str]:
150178 @property
151179 def _http_bearer (self ) -> dict [str , str ]:
152180 bearer_token = self .bearer_token
181+ if bearer_token is None :
182+ return {}
153183 return {"Authorization" : f"Bearer { bearer_token } " }
154184
155185 @property
156186 def _authenticated_api_key (self ) -> dict [str , str ]:
157187 api_key = self .api_key
188+ if api_key is None :
189+ return {}
158190 return {"X-API-Key" : api_key }
159191
160192 @property
161193 def _public_access_key (self ) -> dict [str , str ]:
162194 access_key = self .access_key
195+ if access_key is None :
196+ return {}
163197 return {"X-Access-Key" : access_key }
164198
165199 @property
@@ -171,12 +205,34 @@ def default_headers(self) -> dict[str, str | Omit]:
171205 ** self ._custom_headers ,
172206 }
173207
208+ @override
209+ def _validate_headers (self , headers : Headers , custom_headers : Headers ) -> None :
210+ if self .bearer_token and headers .get ("Authorization" ):
211+ return
212+ if isinstance (custom_headers .get ("Authorization" ), Omit ):
213+ return
214+
215+ if self .api_key and headers .get ("X-API-Key" ):
216+ return
217+ if isinstance (custom_headers .get ("X-API-Key" ), Omit ):
218+ return
219+
220+ if self .access_key and headers .get ("X-Access-Key" ):
221+ return
222+ if isinstance (custom_headers .get ("X-Access-Key" ), Omit ):
223+ return
224+
225+ raise TypeError (
226+ '"Could not resolve authentication method. Expected one of bearer_token, api_key or access_key to be set. Or for one of the `Authorization`, `X-API-Key` or `X-Access-Key` headers to be explicitly omitted"'
227+ )
228+
174229 def copy (
175230 self ,
176231 * ,
177232 bearer_token : str | None = None ,
178233 api_key : str | None = None ,
179234 access_key : str | None = None ,
235+ environment : Literal ["production" , "local" ] | None = None ,
180236 base_url : str | httpx .URL | None = None ,
181237 timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
182238 http_client : httpx .Client | None = None ,
@@ -214,6 +270,7 @@ def copy(
214270 api_key = api_key or self .api_key ,
215271 access_key = access_key or self .access_key ,
216272 base_url = base_url or self .base_url ,
273+ environment = environment or self ._environment ,
217274 timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
218275 http_client = http_client ,
219276 max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
@@ -269,17 +326,20 @@ class AsyncAlpha(AsyncAPIClient):
269326 with_streaming_response : AsyncAlphaWithStreamedResponse
270327
271328 # client options
272- bearer_token : str
273- api_key : str
274- access_key : str
329+ bearer_token : str | None
330+ api_key : str | None
331+ access_key : str | None
332+
333+ _environment : Literal ["production" , "local" ] | NotGiven
275334
276335 def __init__ (
277336 self ,
278337 * ,
279338 bearer_token : str | None = None ,
280339 api_key : str | None = None ,
281340 access_key : str | None = None ,
282- base_url : str | httpx .URL | None = None ,
341+ environment : Literal ["production" , "local" ] | NotGiven = NOT_GIVEN ,
342+ base_url : str | httpx .URL | None | NotGiven = NOT_GIVEN ,
283343 timeout : Union [float , Timeout , None , NotGiven ] = NOT_GIVEN ,
284344 max_retries : int = DEFAULT_MAX_RETRIES ,
285345 default_headers : Mapping [str , str ] | None = None ,
@@ -307,32 +367,41 @@ def __init__(
307367 """
308368 if bearer_token is None :
309369 bearer_token = os .environ .get ("BEARER_TOKEN" )
310- if bearer_token is None :
311- raise AlphaError (
312- "The bearer_token client option must be set either by passing bearer_token to the client or by setting the BEARER_TOKEN environment variable"
313- )
314370 self .bearer_token = bearer_token
315371
316372 if api_key is None :
317373 api_key = os .environ .get ("X_API_KEY" )
318- if api_key is None :
319- raise AlphaError (
320- "The api_key client option must be set either by passing api_key to the client or by setting the X_API_KEY environment variable"
321- )
322374 self .api_key = api_key
323375
324376 if access_key is None :
325377 access_key = os .environ .get ("X_ACCESS_KEY" )
326- if access_key is None :
327- raise AlphaError (
328- "The access_key client option must be set either by passing access_key to the client or by setting the X_ACCESS_KEY environment variable"
329- )
330378 self .access_key = access_key
331379
332- if base_url is None :
333- base_url = os .environ .get ("ALPHA_BASE_URL" )
334- if base_url is None :
335- base_url = f"https://localhost:8080/test-api"
380+ self ._environment = environment
381+
382+ base_url_env = os .environ .get ("ALPHA_BASE_URL" )
383+ if is_given (base_url ) and base_url is not None :
384+ # cast required because mypy doesn't understand the type narrowing
385+ base_url = cast ("str | httpx.URL" , base_url ) # pyright: ignore[reportUnnecessaryCast]
386+ elif is_given (environment ):
387+ if base_url_env and base_url is not None :
388+ raise ValueError (
389+ "Ambiguous URL; The `ALPHA_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None" ,
390+ )
391+
392+ try :
393+ base_url = ENVIRONMENTS [environment ]
394+ except KeyError as exc :
395+ raise ValueError (f"Unknown environment: { environment } " ) from exc
396+ elif base_url_env is not None :
397+ base_url = base_url_env
398+ else :
399+ self ._environment = environment = "production"
400+
401+ try :
402+ base_url = ENVIRONMENTS [environment ]
403+ except KeyError as exc :
404+ raise ValueError (f"Unknown environment: { environment } " ) from exc
336405
337406 super ().__init__ (
338407 version = __version__ ,
@@ -371,16 +440,22 @@ def auth_headers(self) -> dict[str, str]:
371440 @property
372441 def _http_bearer (self ) -> dict [str , str ]:
373442 bearer_token = self .bearer_token
443+ if bearer_token is None :
444+ return {}
374445 return {"Authorization" : f"Bearer { bearer_token } " }
375446
376447 @property
377448 def _authenticated_api_key (self ) -> dict [str , str ]:
378449 api_key = self .api_key
450+ if api_key is None :
451+ return {}
379452 return {"X-API-Key" : api_key }
380453
381454 @property
382455 def _public_access_key (self ) -> dict [str , str ]:
383456 access_key = self .access_key
457+ if access_key is None :
458+ return {}
384459 return {"X-Access-Key" : access_key }
385460
386461 @property
@@ -392,12 +467,34 @@ def default_headers(self) -> dict[str, str | Omit]:
392467 ** self ._custom_headers ,
393468 }
394469
470+ @override
471+ def _validate_headers (self , headers : Headers , custom_headers : Headers ) -> None :
472+ if self .bearer_token and headers .get ("Authorization" ):
473+ return
474+ if isinstance (custom_headers .get ("Authorization" ), Omit ):
475+ return
476+
477+ if self .api_key and headers .get ("X-API-Key" ):
478+ return
479+ if isinstance (custom_headers .get ("X-API-Key" ), Omit ):
480+ return
481+
482+ if self .access_key and headers .get ("X-Access-Key" ):
483+ return
484+ if isinstance (custom_headers .get ("X-Access-Key" ), Omit ):
485+ return
486+
487+ raise TypeError (
488+ '"Could not resolve authentication method. Expected one of bearer_token, api_key or access_key to be set. Or for one of the `Authorization`, `X-API-Key` or `X-Access-Key` headers to be explicitly omitted"'
489+ )
490+
395491 def copy (
396492 self ,
397493 * ,
398494 bearer_token : str | None = None ,
399495 api_key : str | None = None ,
400496 access_key : str | None = None ,
497+ environment : Literal ["production" , "local" ] | None = None ,
401498 base_url : str | httpx .URL | None = None ,
402499 timeout : float | Timeout | None | NotGiven = NOT_GIVEN ,
403500 http_client : httpx .AsyncClient | None = None ,
@@ -435,6 +532,7 @@ def copy(
435532 api_key = api_key or self .api_key ,
436533 access_key = access_key or self .access_key ,
437534 base_url = base_url or self .base_url ,
535+ environment = environment or self ._environment ,
438536 timeout = self .timeout if isinstance (timeout , NotGiven ) else timeout ,
439537 http_client = http_client ,
440538 max_retries = max_retries if is_given (max_retries ) else self .max_retries ,
0 commit comments