Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,93 @@
# CrowdStrike AIDR Python SDK

Python SDK for CrowdStrike AIDR.

## Installation

```bash
pip install crowdstrike-aidr
```

## Requirements

Python v3.13 or greater.

## Usage

```python
from crowdstrike_aidr import AIGuard

client = AIGuard(base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}")

response = client.guard_chat_completions(
guard_input={
"messages": [
{"role": "user", "content": "Hello, world!"}
]
}
)
```

## Timeouts

The SDK uses `httpx.Timeout` for timeout configuration. By default, requests
have a timeout of 60 seconds with a 5 second connection timeout.

You can configure timeouts in two ways:

### Client-level timeout

Set a default timeout for all requests made by the client:

```python
import httpx
from crowdstrike_aidr import AIGuard

# Using a float (total timeout in seconds).
client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
timeout=30.0
)

# Using httpx.Timeout for more granular control.
client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
timeout=httpx.Timeout(timeout=60.0, connect=10.0)
)
```

### Request-level timeout

Override the timeout for a specific request:

```python
# Using a float (total timeout in seconds).
response = client.guard_chat_completions(
guard_input={"messages": [...]},
timeout=120.0
)

# Using httpx.Timeout for more granular control.
response = client.guard_chat_completions(
guard_input={"messages": [...]},
timeout=httpx.Timeout(timeout=120.0, connect=15.0)
)
```

## Retries

The SDK automatically retries failed requests with exponential backoff. By
default, the client will retry up to 2 times.

### Client-level retries

Set the maximum number of retries for all requests:

```python
from crowdstrike_aidr import AIGuard

client = AIGuard(
base_url_template="https://api.crowdstrike.com/aidr/{SERVICE_NAME}",
max_retries=5 # Retry up to 5 times
)
```
24 changes: 12 additions & 12 deletions src/crowdstrike_aidr/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,20 @@ def _merge_mappings(

class BaseClient[T: httpx.Client | httpx.AsyncClient]:
_client: T
_service_name: str

def __init__(
self,
*,
base_url: str | URL,
base_url_template: str,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
self._base_url = URL(base_url)
self.max_retries = max_retries
self.timeout = timeout
self._base_url_template = base_url_template
self._custom_headers = custom_headers or {}
self._custom_query = custom_query or {}

Expand All @@ -102,11 +103,8 @@ def auth_headers(self) -> dict[str, str]:

@property
def base_url(self) -> URL:
return self._base_url

@base_url.setter
def base_url(self, url: URL | str) -> None:
self._base_url = url if isinstance(url, URL) else URL(url)
resolved_url = self._base_url_template.replace("{SERVICE_NAME}", self._service_name)
return URL(resolved_url)

@property
def default_headers(self) -> dict[str, str | Omit]:
Expand Down Expand Up @@ -191,8 +189,9 @@ def _prepare_url(self, url: str) -> URL:

merge_url = URL(url)
if merge_url.is_relative_url:
merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/")
return self.base_url.copy_with(raw_path=merge_raw_path)
base_url = self.base_url
merge_raw_path = base_url.raw_path + merge_url.raw_path.lstrip(b"/")
return base_url.copy_with(raw_path=merge_raw_path)

return merge_url

Expand Down Expand Up @@ -241,7 +240,7 @@ class SyncAPIClient(BaseClient[httpx.Client]):
def __init__(
self,
*,
base_url: str | URL,
base_url_template: str,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout = DEFAULT_TIMEOUT,
http_client: httpx.Client | None = None,
Expand All @@ -255,12 +254,13 @@ def __init__(

super().__init__(
timeout=cast(Timeout, timeout),
base_url=base_url,
base_url_template=base_url_template,
max_retries=max_retries,
custom_headers=custom_headers,
custom_query=custom_query,
)
self._client = http_client or SyncHttpxClientWrapper(base_url=self._base_url, timeout=self.timeout)
resolved_base_url = self.base_url
self._client = http_client or SyncHttpxClientWrapper(base_url=resolved_base_url, timeout=self.timeout)

def _post(
self,
Expand Down
2 changes: 2 additions & 0 deletions src/crowdstrike_aidr/services/ai_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def _transform_typeddict(data: Mapping[str, object]) -> Mapping[str, object]:


class AIGuard(SyncAPIClient):
_service_name: str = "aiguard"

def guard_chat_completions(
self,
*,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ai_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from .utils import assert_matches_type

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
base_url_template = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")


@pytest.fixture(scope="session")
def client(request: pytest.FixtureRequest) -> Iterator[AIGuard]:
yield AIGuard(base_url=base_url)
yield AIGuard(base_url_template=base_url_template)


def test_guard_chat_completions(client: AIGuard) -> None:
Expand Down