Skip to content

Commit 28697b9

Browse files
committed
Annotating the openai.Omit type so that ModelSettings can be serialized by pydantic
1 parent c2005f8 commit 28697b9

File tree

2 files changed

+68
-5
lines changed

2 files changed

+68
-5
lines changed

src/agents/model_settings.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,46 @@
22

33
import dataclasses
44
from dataclasses import dataclass, fields, replace
5-
from typing import Any, Literal
5+
from typing import Any, Literal, Annotated, TypeAlias, Mapping
66

7-
from openai._types import Body, Headers, Query
7+
from openai import Omit as _Omit
8+
from openai._types import Body, Query
89
from openai.types.responses import ResponseIncludable
910
from openai.types.shared import Reasoning
10-
from pydantic import BaseModel
11-
11+
from pydantic import BaseModel, GetCoreSchemaHandler
12+
from pydantic_core import core_schema
13+
14+
class _OmitTypeAnnotation:
15+
@classmethod
16+
def __get_pydantic_core_schema__(
17+
cls,
18+
_source_type: Any,
19+
_handler: GetCoreSchemaHandler,
20+
) -> core_schema.CoreSchema:
21+
def validate_from_none(value: None) -> _Omit:
22+
return _Omit()
23+
24+
from_none_schema = core_schema.chain_schema(
25+
[
26+
core_schema.none_schema(),
27+
core_schema.no_info_plain_validator_function(validate_from_none),
28+
]
29+
)
30+
return core_schema.json_or_python_schema(
31+
json_schema=from_none_schema,
32+
python_schema=core_schema.union_schema(
33+
[
34+
# check if it's an instance first before doing any further work
35+
core_schema.is_instance_schema(_Omit),
36+
from_none_schema,
37+
]
38+
),
39+
serialization=core_schema.plain_serializer_function_ser_schema(
40+
lambda instance: None
41+
),
42+
)
43+
Omit = Annotated[_Omit, _OmitTypeAnnotation]
44+
Headers: TypeAlias = Mapping[str, str | Omit]
1245

1346
@dataclass
1447
class ModelSettings:

tests/model_settings/test_serialization.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from dataclasses import fields
33

44
from openai.types.shared import Reasoning
5-
5+
from pydantic import TypeAdapter
6+
from pydantic_core import to_json
67
from agents.model_settings import ModelSettings
78

89

@@ -132,3 +133,32 @@ def test_extra_args_resolve_both_none() -> None:
132133
assert resolved.extra_args is None
133134
assert resolved.temperature == 0.5
134135
assert resolved.top_p == 0.9
136+
137+
def test_pydantic_serialization() -> None:
138+
139+
"""Tests whether ModelSettings can be serialized with Pydantic."""
140+
141+
# First, lets create a ModelSettings instance
142+
model_settings = ModelSettings(
143+
temperature=0.5,
144+
top_p=0.9,
145+
frequency_penalty=0.0,
146+
presence_penalty=0.0,
147+
tool_choice="auto",
148+
parallel_tool_calls=True,
149+
truncation="auto",
150+
max_tokens=100,
151+
reasoning=Reasoning(),
152+
metadata={"foo": "bar"},
153+
store=False,
154+
include_usage=False,
155+
extra_query={"foo": "bar"},
156+
extra_body={"foo": "bar"},
157+
extra_headers={"foo": "bar"},
158+
extra_args={"custom_param": "value", "another_param": 42},
159+
)
160+
161+
json = to_json(model_settings)
162+
deserialized = TypeAdapter(ModelSettings).validate_json(json)
163+
164+
assert model_settings == deserialized

0 commit comments

Comments
 (0)