diff --git a/README.md b/README.md index a6962e1..5ecfc99 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,93 @@ # CrowdStrike AIDR Python SDK Python SDK for CrowdStrike AIDR. + +## Installation + +```bash +pip install crowdstrike-aidr +``` + +## Requirements + +Python v3.13 or greater. + +## Usage + +```python +from crowdstrike_aidr import AIGuard + +client = AIGuard(base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}") + +response = client.guard_chat_completions( + guard_input={ + "messages": [ + {"role": "user", "content": "Hello, world!"} + ] + } +) +``` + +## Timeouts + +The SDK uses `httpx.Timeout` for timeout configuration. By default, requests +have a timeout of 60 seconds with a 5 second connection timeout. + +You can configure timeouts in two ways: + +### Client-level timeout + +Set a default timeout for all requests made by the client: + +```python +import httpx +from crowdstrike_aidr import AIGuard + +# Using a float (total timeout in seconds). +client = AIGuard( + base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", + timeout=30.0 +) + +# Using httpx.Timeout for more granular control. +client = AIGuard( + base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", + timeout=httpx.Timeout(timeout=60.0, connect=10.0) +) +``` + +### Request-level timeout + +Override the timeout for a specific request: + +```python +# Using a float (total timeout in seconds). +response = client.guard_chat_completions( + guard_input={"messages": [...]}, + timeout=120.0 +) + +# Using httpx.Timeout for more granular control. +response = client.guard_chat_completions( + guard_input={"messages": [...]}, + timeout=httpx.Timeout(timeout=120.0, connect=15.0) +) +``` + +## Retries + +The SDK automatically retries failed requests with exponential backoff. By +default, the client will retry up to 2 times. + +### Client-level retries + +Set the maximum number of retries for all requests: + +```python +from crowdstrike_aidr import AIGuard + +client = AIGuard( + base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}", + max_retries=5 # Retry up to 5 times +) +``` diff --git a/src/crowdstrike_aidr/_client.py b/src/crowdstrike_aidr/_client.py index f01e14d..ad5f2cf 100644 --- a/src/crowdstrike_aidr/_client.py +++ b/src/crowdstrike_aidr/_client.py @@ -80,19 +80,20 @@ def _merge_mappings( class BaseClient[T: httpx.Client | httpx.AsyncClient]: _client: T + _service_name: str def __init__( self, *, - base_url: str | URL, + base_url_template: str, max_retries: int = DEFAULT_MAX_RETRIES, timeout: float | Timeout | None = DEFAULT_TIMEOUT, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, ) -> None: - self._base_url = URL(base_url) self.max_retries = max_retries self.timeout = timeout + self._base_url_template = base_url_template self._custom_headers = custom_headers or {} self._custom_query = custom_query or {} @@ -102,11 +103,8 @@ def auth_headers(self) -> dict[str, str]: @property def base_url(self) -> URL: - return self._base_url - - @base_url.setter - def base_url(self, url: URL | str) -> None: - self._base_url = url if isinstance(url, URL) else URL(url) + resolved_url = self._base_url_template.replace("{SERVICE_NAME}", self._service_name) + return URL(resolved_url) @property def default_headers(self) -> dict[str, str | Omit]: @@ -191,8 +189,9 @@ def _prepare_url(self, url: str) -> URL: merge_url = URL(url) if merge_url.is_relative_url: - merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/") - return self.base_url.copy_with(raw_path=merge_raw_path) + base_url = self.base_url + merge_raw_path = base_url.raw_path + merge_url.raw_path.lstrip(b"/") + return base_url.copy_with(raw_path=merge_raw_path) return merge_url @@ -241,7 +240,7 @@ class SyncAPIClient(BaseClient[httpx.Client]): def __init__( self, *, - base_url: str | URL, + base_url_template: str, max_retries: int = DEFAULT_MAX_RETRIES, timeout: float | Timeout = DEFAULT_TIMEOUT, http_client: httpx.Client | None = None, @@ -255,12 +254,13 @@ def __init__( super().__init__( timeout=cast(Timeout, timeout), - base_url=base_url, + base_url_template=base_url_template, max_retries=max_retries, custom_headers=custom_headers, custom_query=custom_query, ) - self._client = http_client or SyncHttpxClientWrapper(base_url=self._base_url, timeout=self.timeout) + resolved_base_url = self.base_url + self._client = http_client or SyncHttpxClientWrapper(base_url=resolved_base_url, timeout=self.timeout) def _post( self, diff --git a/src/crowdstrike_aidr/services/ai_guard.py b/src/crowdstrike_aidr/services/ai_guard.py index a69f5bc..6972160 100644 --- a/src/crowdstrike_aidr/services/ai_guard.py +++ b/src/crowdstrike_aidr/services/ai_guard.py @@ -16,6 +16,8 @@ def _transform_typeddict(data: Mapping[str, object]) -> Mapping[str, object]: class AIGuard(SyncAPIClient): + _service_name: str = "aiguard" + def guard_chat_completions( self, *, diff --git a/tests/test_ai_guard.py b/tests/test_ai_guard.py index 37b65ab..2c8e4e0 100644 --- a/tests/test_ai_guard.py +++ b/tests/test_ai_guard.py @@ -10,12 +10,12 @@ from .utils import assert_matches_type -base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +base_url_template = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @pytest.fixture(scope="session") def client(request: pytest.FixtureRequest) -> Iterator[AIGuard]: - yield AIGuard(base_url=base_url) + yield AIGuard(base_url_template=base_url_template) def test_guard_chat_completions(client: AIGuard) -> None: