Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ Python v3.13 or greater.
```python
from crowdstrike_aidr import AIGuard

client = AIGuard(base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}")
client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
token="my API token"
)

response = client.guard_chat_completions(
guard_input={
Expand Down Expand Up @@ -46,13 +49,15 @@ from crowdstrike_aidr import AIGuard
# Using a float (total timeout in seconds).
client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
timeout=30.0
token="my API token",
timeout=30.0,
)

# Using httpx.Timeout for more granular control.
client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
timeout=httpx.Timeout(timeout=60.0, connect=10.0)
token="my API token",
timeout=httpx.Timeout(timeout=60.0, connect=10.0),
)
```

Expand All @@ -77,17 +82,14 @@ response = client.guard_chat_completions(
## Retries

The SDK automatically retries failed requests with exponential backoff. By
default, the client will retry up to 2 times.

### Client-level retries

Set the maximum number of retries for all requests:
default, the client will retry up to 2 times. Set `max_retries` during client
creation to change this.

```python
from crowdstrike_aidr import AIGuard

client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
max_retries=5 # Retry up to 5 times
max_retries=5 # Retry up to 5 times.
)
```
12 changes: 11 additions & 1 deletion src/crowdstrike_aidr/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from collections.abc import Mapping
from random import random
from typing import TYPE_CHECKING, Any, cast, get_origin
from typing import TYPE_CHECKING, Any, cast, get_origin, override

import httpx
from httpx import URL, Timeout
Expand Down Expand Up @@ -237,11 +237,13 @@ def __del__(self) -> None:

class SyncAPIClient(BaseClient[httpx.Client]):
_client: httpx.Client
_token: str

def __init__(
self,
*,
base_url_template: str,
token: str,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout = DEFAULT_TIMEOUT,
http_client: httpx.Client | None = None,
Expand All @@ -260,9 +262,17 @@ def __init__(
custom_headers=custom_headers,
custom_query=custom_query,
)

self._token = token

resolved_base_url = self.base_url
self._client = http_client or SyncHttpxClientWrapper(base_url=resolved_base_url, timeout=self.timeout)

@property
@override
def auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._token}"}

def _post(
self,
path: str,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ai_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@pytest.fixture(scope="session")
def client(request: pytest.FixtureRequest) -> Iterator[AIGuard]:
yield AIGuard(base_url_template=base_url_template)
yield AIGuard(base_url_template=base_url_template, token="my API token")


def test_guard_chat_completions(client: AIGuard) -> None:
Expand Down