diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index f844eb87c..1ccede853 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -1,14 +1,50 @@ from __future__ import annotations import dataclasses +from collections.abc import Mapping from dataclasses import dataclass, fields, replace -from typing import Any, Literal +from typing import Annotated, Any, Literal, Union -from openai._types import Body, Headers, Query +from openai import Omit as _Omit +from openai._types import Body, Query from openai.types.responses import ResponseIncludable from openai.types.shared import Reasoning -from pydantic import BaseModel - +from pydantic import BaseModel, GetCoreSchemaHandler +from pydantic_core import core_schema +from typing_extensions import TypeAlias + + +class _OmitTypeAnnotation: + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + def validate_from_none(value: None) -> _Omit: + return _Omit() + + from_none_schema = core_schema.chain_schema( + [ + core_schema.none_schema(), + core_schema.no_info_plain_validator_function(validate_from_none), + ] + ) + return core_schema.json_or_python_schema( + json_schema=from_none_schema, + python_schema=core_schema.union_schema( + [ + # check if it's an instance first before doing any further work + core_schema.is_instance_schema(_Omit), + from_none_schema, + ] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: None + ), + ) +Omit = Annotated[_Omit, _OmitTypeAnnotation] +Headers: TypeAlias = Mapping[str, Union[str, Omit]] @dataclass class ModelSettings: diff --git a/tests/model_settings/test_serialization.py b/tests/model_settings/test_serialization.py index 2bbc7ce2c..94d11def3 100644 --- a/tests/model_settings/test_serialization.py +++ b/tests/model_settings/test_serialization.py @@ -2,6 +2,8 @@ from dataclasses import fields from openai.types.shared import Reasoning +from pydantic import TypeAdapter +from pydantic_core import to_json from agents.model_settings import ModelSettings @@ -132,3 +134,32 @@ def test_extra_args_resolve_both_none() -> None: assert resolved.extra_args is None assert resolved.temperature == 0.5 assert resolved.top_p == 0.9 + +def test_pydantic_serialization() -> None: + + """Tests whether ModelSettings can be serialized with Pydantic.""" + + # First, lets create a ModelSettings instance + model_settings = ModelSettings( + temperature=0.5, + top_p=0.9, + frequency_penalty=0.0, + presence_penalty=0.0, + tool_choice="auto", + parallel_tool_calls=True, + truncation="auto", + max_tokens=100, + reasoning=Reasoning(), + metadata={"foo": "bar"}, + store=False, + include_usage=False, + extra_query={"foo": "bar"}, + extra_body={"foo": "bar"}, + extra_headers={"foo": "bar"}, + extra_args={"custom_param": "value", "another_param": 42}, + ) + + json = to_json(model_settings) + deserialized = TypeAdapter(ModelSettings).validate_json(json) + + assert model_settings == deserialized