|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import os |
6 | | -from typing import Any, Union, Mapping |
7 | | -from typing_extensions import Self, override |
| 6 | +from typing import Any, Dict, Union, Mapping, cast |
| 7 | +from typing_extensions import Self, Literal, override |
8 | 8 |
|
9 | 9 | import httpx |
10 | 10 |
|
|
92 | 92 | from .resources.fine_tuning import fine_tuning |
93 | 93 | from .resources.organization import organization |
94 | 94 |
|
95 | | -__all__ = ["Timeout", "Transport", "ProxiesTypes", "RequestOptions", "Hanzo", "AsyncHanzo", "Client", "AsyncClient"] |
| 95 | +__all__ = [ |
| 96 | + "ENVIRONMENTS", |
| 97 | + "Timeout", |
| 98 | + "Transport", |
| 99 | + "ProxiesTypes", |
| 100 | + "RequestOptions", |
| 101 | + "Hanzo", |
| 102 | + "AsyncHanzo", |
| 103 | + "Client", |
| 104 | + "AsyncClient", |
| 105 | +] |
| 106 | + |
| 107 | +ENVIRONMENTS: Dict[str, str] = { |
| 108 | + "production": "https://api.hanzo.ai", |
| 109 | + "sandbox": "https://api.sandbox.hanzo.ai", |
| 110 | +} |
96 | 111 |
|
97 | 112 |
|
98 | 113 | class Hanzo(SyncAPIClient): |
@@ -150,11 +165,14 @@ class Hanzo(SyncAPIClient): |
150 | 165 | # client options |
151 | 166 | api_key: str |
152 | 167 |
|
| 168 | + _environment: Literal["production", "sandbox"] | NotGiven |
| 169 | + |
153 | 170 | def __init__( |
154 | 171 | self, |
155 | 172 | *, |
156 | 173 | api_key: str | None = None, |
157 | | - base_url: str | httpx.URL | None = None, |
| 174 | + environment: Literal["production", "sandbox"] | NotGiven = NOT_GIVEN, |
| 175 | + base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN, |
158 | 176 | timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, |
159 | 177 | max_retries: int = DEFAULT_MAX_RETRIES, |
160 | 178 | default_headers: Mapping[str, str] | None = None, |
@@ -185,10 +203,31 @@ def __init__( |
185 | 203 | ) |
186 | 204 | self.api_key = api_key |
187 | 205 |
|
188 | | - if base_url is None: |
189 | | - base_url = os.environ.get("HANZO_BASE_URL") |
190 | | - if base_url is None: |
191 | | - base_url = f"https://api.hanzo.ai" |
| 206 | + self._environment = environment |
| 207 | + |
| 208 | + base_url_env = os.environ.get("HANZO_BASE_URL") |
| 209 | + if is_given(base_url) and base_url is not None: |
| 210 | + # cast required because mypy doesn't understand the type narrowing |
| 211 | + base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast] |
| 212 | + elif is_given(environment): |
| 213 | + if base_url_env and base_url is not None: |
| 214 | + raise ValueError( |
| 215 | + "Ambiguous URL; The `HANZO_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None", |
| 216 | + ) |
| 217 | + |
| 218 | + try: |
| 219 | + base_url = ENVIRONMENTS[environment] |
| 220 | + except KeyError as exc: |
| 221 | + raise ValueError(f"Unknown environment: {environment}") from exc |
| 222 | + elif base_url_env is not None: |
| 223 | + base_url = base_url_env |
| 224 | + else: |
| 225 | + self._environment = environment = "production" |
| 226 | + |
| 227 | + try: |
| 228 | + base_url = ENVIRONMENTS[environment] |
| 229 | + except KeyError as exc: |
| 230 | + raise ValueError(f"Unknown environment: {environment}") from exc |
192 | 231 |
|
193 | 232 | super().__init__( |
194 | 233 | version=__version__, |
@@ -276,6 +315,7 @@ def copy( |
276 | 315 | self, |
277 | 316 | *, |
278 | 317 | api_key: str | None = None, |
| 318 | + environment: Literal["production", "sandbox"] | None = None, |
279 | 319 | base_url: str | httpx.URL | None = None, |
280 | 320 | timeout: float | Timeout | None | NotGiven = NOT_GIVEN, |
281 | 321 | http_client: httpx.Client | None = None, |
@@ -311,6 +351,7 @@ def copy( |
311 | 351 | return self.__class__( |
312 | 352 | api_key=api_key or self.api_key, |
313 | 353 | base_url=base_url or self.base_url, |
| 354 | + environment=environment or self._environment, |
314 | 355 | timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, |
315 | 356 | http_client=http_client, |
316 | 357 | max_retries=max_retries if is_given(max_retries) else self.max_retries, |
@@ -431,11 +472,14 @@ class AsyncHanzo(AsyncAPIClient): |
431 | 472 | # client options |
432 | 473 | api_key: str |
433 | 474 |
|
| 475 | + _environment: Literal["production", "sandbox"] | NotGiven |
| 476 | + |
434 | 477 | def __init__( |
435 | 478 | self, |
436 | 479 | *, |
437 | 480 | api_key: str | None = None, |
438 | | - base_url: str | httpx.URL | None = None, |
| 481 | + environment: Literal["production", "sandbox"] | NotGiven = NOT_GIVEN, |
| 482 | + base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN, |
439 | 483 | timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, |
440 | 484 | max_retries: int = DEFAULT_MAX_RETRIES, |
441 | 485 | default_headers: Mapping[str, str] | None = None, |
@@ -466,10 +510,31 @@ def __init__( |
466 | 510 | ) |
467 | 511 | self.api_key = api_key |
468 | 512 |
|
469 | | - if base_url is None: |
470 | | - base_url = os.environ.get("HANZO_BASE_URL") |
471 | | - if base_url is None: |
472 | | - base_url = f"https://api.hanzo.ai" |
| 513 | + self._environment = environment |
| 514 | + |
| 515 | + base_url_env = os.environ.get("HANZO_BASE_URL") |
| 516 | + if is_given(base_url) and base_url is not None: |
| 517 | + # cast required because mypy doesn't understand the type narrowing |
| 518 | + base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast] |
| 519 | + elif is_given(environment): |
| 520 | + if base_url_env and base_url is not None: |
| 521 | + raise ValueError( |
| 522 | + "Ambiguous URL; The `HANZO_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None", |
| 523 | + ) |
| 524 | + |
| 525 | + try: |
| 526 | + base_url = ENVIRONMENTS[environment] |
| 527 | + except KeyError as exc: |
| 528 | + raise ValueError(f"Unknown environment: {environment}") from exc |
| 529 | + elif base_url_env is not None: |
| 530 | + base_url = base_url_env |
| 531 | + else: |
| 532 | + self._environment = environment = "production" |
| 533 | + |
| 534 | + try: |
| 535 | + base_url = ENVIRONMENTS[environment] |
| 536 | + except KeyError as exc: |
| 537 | + raise ValueError(f"Unknown environment: {environment}") from exc |
473 | 538 |
|
474 | 539 | super().__init__( |
475 | 540 | version=__version__, |
@@ -557,6 +622,7 @@ def copy( |
557 | 622 | self, |
558 | 623 | *, |
559 | 624 | api_key: str | None = None, |
| 625 | + environment: Literal["production", "sandbox"] | None = None, |
560 | 626 | base_url: str | httpx.URL | None = None, |
561 | 627 | timeout: float | Timeout | None | NotGiven = NOT_GIVEN, |
562 | 628 | http_client: httpx.AsyncClient | None = None, |
@@ -592,6 +658,7 @@ def copy( |
592 | 658 | return self.__class__( |
593 | 659 | api_key=api_key or self.api_key, |
594 | 660 | base_url=base_url or self.base_url, |
| 661 | + environment=environment or self._environment, |
595 | 662 | timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, |
596 | 663 | http_client=http_client, |
597 | 664 | max_retries=max_retries if is_given(max_retries) else self.max_retries, |
|
0 commit comments