Skip to content

Commit 1f3e470

Browse files
authored
fix: serialize Pydantic models (#13)
Move transform code into its own module, since it'll be needed for other services eventually, and ensure it handles Pydantic models like `ExtraInfo`.
1 parent a71d1fa commit 1f3e470

File tree

4 files changed

+53
-14
lines changed

4 files changed

+53
-14
lines changed

flake.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/crowdstrike_aidr/_transform.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Utilities for converting data structures to JSON-serializable formats."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Mapping
6+
7+
from pydantic import BaseModel
8+
9+
from ._utils import is_given, is_mapping
10+
11+
12+
def _transform_typeddict(data: Mapping[str, object]) -> dict[str, object]:
13+
"""
14+
Transform a TypedDict-like mapping.
15+
16+
Args:
17+
data: A `Mapping` to transform.
18+
19+
Returns:
20+
A new dictionary with transformed values, excluding unset entries.
21+
"""
22+
return {key: transform(value) for key, value in data.items() if is_given(value)}
23+
24+
25+
def transform(data: object) -> object:
26+
"""
27+
Transform an object into a JSON-serializable format.
28+
29+
Args:
30+
data: The object to transform.
31+
32+
Returns:
33+
A JSON-serializable representation of the input data.
34+
"""
35+
if is_mapping(data):
36+
return _transform_typeddict(data)
37+
38+
if isinstance(data, BaseModel):
39+
return data.model_dump(exclude_unset=True, mode="json")
40+
41+
return data

src/crowdstrike_aidr/services/ai_guard.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
from __future__ import annotations
22

3-
from collections.abc import Mapping
43
from typing import Literal
54

65
import httpx
76

87
from .._client import SyncAPIClient, make_request_options
8+
from .._transform import transform
99
from .._types import Body, Headers, NotGiven, Omit, Query, not_given, omit
10-
from .._utils import is_given
1110
from ..models.ai_guard import ExtraInfo, GuardChatCompletionsResponse
1211

1312

14-
def _transform_typeddict(data: Mapping[str, object]) -> Mapping[str, object]:
15-
return {key: value for key, value in data.items() if is_given(value)}
16-
17-
1813
class AIGuard(SyncAPIClient):
1914
_service_name: str = "aiguard"
2015

@@ -83,7 +78,7 @@ def guard_chat_completions(
8378
"""
8479
return self._post(
8580
"/v1/guard_chat_completions",
86-
body=_transform_typeddict(
81+
body=transform(
8782
{
8883
"guard_input": guard_input,
8984
"app_id": app_id,

tests/test_ai_guard.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
from crowdstrike_aidr import AIGuard
9-
from crowdstrike_aidr.models.ai_guard import GuardChatCompletionsResponse
9+
from crowdstrike_aidr.models.ai_guard import ExtraInfo, GuardChatCompletionsResponse
1010

1111
from .utils import assert_matches_type
1212

@@ -19,5 +19,8 @@ def client(request: pytest.FixtureRequest) -> Iterator[AIGuard]:
1919

2020

2121
def test_guard_chat_completions(client: AIGuard) -> None:
22-
response = client.guard_chat_completions(guard_input={"messages": [{"role": "user", "content": "Hello, world!"}]})
22+
response = client.guard_chat_completions(
23+
guard_input={"messages": [{"role": "user", "content": "Hello, world!"}]},
24+
extra_info=ExtraInfo(app_name="app_name"),
25+
)
2326
assert_matches_type(GuardChatCompletionsResponse, response, path=["guard_chat_completions"])

0 commit comments

Comments
 (0)