From 895b4242e467e79969b1ebac44b5e949c8c53371 Mon Sep 17 00:00:00 2001 From: Kenan Yildirim Date: Sat, 3 Jan 2026 09:42:48 -0500 Subject: [PATCH] fix: serialize Pydantic models Move transform code into its own module, since it'll be needed for other services eventually, and ensure it handles Pydantic models like `ExtraInfo`. --- flake.lock | 10 +++--- src/crowdstrike_aidr/_transform.py | 41 +++++++++++++++++++++++ src/crowdstrike_aidr/services/ai_guard.py | 9 ++--- tests/test_ai_guard.py | 7 ++-- 4 files changed, 53 insertions(+), 14 deletions(-) create mode 100644 src/crowdstrike_aidr/_transform.py diff --git a/flake.lock b/flake.lock index f434476..063ad8d 100644 --- a/flake.lock +++ b/flake.lock @@ -20,12 +20,12 @@ }, "nixpkgs": { "locked": { - "lastModified": 1765779637, - "narHash": "sha256-KJ2wa/BLSrTqDjbfyNx70ov/HdgNBCBBSQP3BIzKnv4=", - "rev": "1306659b587dc277866c7b69eb97e5f07864d8c4", - "revCount": 912002, + "lastModified": 1767116409, + "narHash": "sha256-5vKw92l1GyTnjoLzEagJy5V5mDFck72LiQWZSOnSicw=", + "rev": "cad22e7d996aea55ecab064e84834289143e44a0", + "revCount": 919991, "type": "tarball", - "url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.912002%2Brev-1306659b587dc277866c7b69eb97e5f07864d8c4/019b2463-7b8e-7042-8b7e-490d08a3cd7a/source.tar.gz" + "url": "https://api.flakehub.com/f/pinned/NixOS/nixpkgs/0.1.919991%2Brev-cad22e7d996aea55ecab064e84834289143e44a0/019b7874-2314-7694-916e-2006d5d405fb/source.tar.gz" }, "original": { "type": "tarball", diff --git a/src/crowdstrike_aidr/_transform.py b/src/crowdstrike_aidr/_transform.py new file mode 100644 index 0000000..92a733d --- /dev/null +++ b/src/crowdstrike_aidr/_transform.py @@ -0,0 +1,41 @@ +"""Utilities for converting data structures to JSON-serializable formats.""" + +from __future__ import annotations + +from collections.abc import Mapping + +from pydantic import BaseModel + +from ._utils import is_given, is_mapping + + +def _transform_typeddict(data: Mapping[str, object]) -> dict[str, object]: + """ + Transform a TypedDict-like mapping. + + Args: + data: A `Mapping` to transform. + + Returns: + A new dictionary with transformed values, excluding unset entries. + """ + return {key: transform(value) for key, value in data.items() if is_given(value)} + + +def transform(data: object) -> object: + """ + Transform an object into a JSON-serializable format. + + Args: + data: The object to transform. + + Returns: + A JSON-serializable representation of the input data. + """ + if is_mapping(data): + return _transform_typeddict(data) + + if isinstance(data, BaseModel): + return data.model_dump(exclude_unset=True, mode="json") + + return data diff --git a/src/crowdstrike_aidr/services/ai_guard.py b/src/crowdstrike_aidr/services/ai_guard.py index 6972160..049e09a 100644 --- a/src/crowdstrike_aidr/services/ai_guard.py +++ b/src/crowdstrike_aidr/services/ai_guard.py @@ -1,20 +1,15 @@ from __future__ import annotations -from collections.abc import Mapping from typing import Literal import httpx from .._client import SyncAPIClient, make_request_options +from .._transform import transform from .._types import Body, Headers, NotGiven, Omit, Query, not_given, omit -from .._utils import is_given from ..models.ai_guard import ExtraInfo, GuardChatCompletionsResponse -def _transform_typeddict(data: Mapping[str, object]) -> Mapping[str, object]: - return {key: value for key, value in data.items() if is_given(value)} - - class AIGuard(SyncAPIClient): _service_name: str = "aiguard" @@ -83,7 +78,7 @@ def guard_chat_completions( """ return self._post( "/v1/guard_chat_completions", - body=_transform_typeddict( + body=transform( { "guard_input": guard_input, "app_id": app_id, diff --git a/tests/test_ai_guard.py b/tests/test_ai_guard.py index de6b0a6..8773f1b 100644 --- a/tests/test_ai_guard.py +++ b/tests/test_ai_guard.py @@ -6,7 +6,7 @@ import pytest from crowdstrike_aidr import AIGuard -from crowdstrike_aidr.models.ai_guard import GuardChatCompletionsResponse +from crowdstrike_aidr.models.ai_guard import ExtraInfo, GuardChatCompletionsResponse from .utils import assert_matches_type @@ -19,5 +19,8 @@ def client(request: pytest.FixtureRequest) -> Iterator[AIGuard]: def test_guard_chat_completions(client: AIGuard) -> None: - response = client.guard_chat_completions(guard_input={"messages": [{"role": "user", "content": "Hello, world!"}]}) + response = client.guard_chat_completions( + guard_input={"messages": [{"role": "user", "content": "Hello, world!"}]}, + extra_info=ExtraInfo(app_name="app_name"), + ) assert_matches_type(GuardChatCompletionsResponse, response, path=["guard_chat_completions"])