diff --git a/README.md b/README.md index 5ecfc99..8b86bf6 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,10 @@ Python v3.13 or greater. ```python from crowdstrike_aidr import AIGuard -client = AIGuard(base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}") +client = AIGuard( + base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", + token="my API token" +) response = client.guard_chat_completions( guard_input={ @@ -46,13 +49,15 @@ from crowdstrike_aidr import AIGuard # Using a float (total timeout in seconds). client = AIGuard( base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", - timeout=30.0 + token="my API token", + timeout=30.0, ) # Using httpx.Timeout for more granular control. client = AIGuard( base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", - timeout=httpx.Timeout(timeout=60.0, connect=10.0) + token="my API token", + timeout=httpx.Timeout(timeout=60.0, connect=10.0), ) ``` @@ -77,17 +82,14 @@ response = client.guard_chat_completions( ## Retries The SDK automatically retries failed requests with exponential backoff. By -default, the client will retry up to 2 times. - -### Client-level retries - -Set the maximum number of retries for all requests: +default, the client will retry up to 2 times. Set `max_retries` during client +creation to change this. ```python from crowdstrike_aidr import AIGuard client = AIGuard( base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", - max_retries=5 # Retry up to 5 times + max_retries=5 # Retry up to 5 times. ) ``` diff --git a/src/crowdstrike_aidr/_client.py b/src/crowdstrike_aidr/_client.py index 7f43780..ff35899 100644 --- a/src/crowdstrike_aidr/_client.py +++ b/src/crowdstrike_aidr/_client.py @@ -5,7 +5,7 @@ import time from collections.abc import Mapping from random import random -from typing import TYPE_CHECKING, Any, cast, get_origin +from typing import TYPE_CHECKING, Any, cast, get_origin, override import httpx from httpx import URL, Timeout @@ -237,11 +237,13 @@ def __del__(self) -> None: class SyncAPIClient(BaseClient[httpx.Client]): _client: httpx.Client + _token: str def __init__( self, *, base_url_template: str, + token: str, max_retries: int = DEFAULT_MAX_RETRIES, timeout: float | Timeout = DEFAULT_TIMEOUT, http_client: httpx.Client | None = None, @@ -260,9 +262,17 @@ def __init__( custom_headers=custom_headers, custom_query=custom_query, ) + + self._token = token + resolved_base_url = self.base_url self._client = http_client or SyncHttpxClientWrapper(base_url=resolved_base_url, timeout=self.timeout) + @property + @override + def auth_headers(self) -> dict[str, str]: + return {"Authorization": f"Bearer {self._token}"} + def _post( self, path: str, diff --git a/tests/test_ai_guard.py b/tests/test_ai_guard.py index 2c8e4e0..de6b0a6 100644 --- a/tests/test_ai_guard.py +++ b/tests/test_ai_guard.py @@ -15,7 +15,7 @@ @pytest.fixture(scope="session") def client(request: pytest.FixtureRequest) -> Iterator[AIGuard]: - yield AIGuard(base_url_template=base_url_template) + yield AIGuard(base_url_template=base_url_template, token="my API token") def test_guard_chat_completions(client: AIGuard) -> None: