Skip to content

Commit 9efdd40

Browse files
committed
Ensure ToolCallPart resulting from TestModel(custom_output_args=...) always holds a dict
1 parent 606643f commit 9efdd40

File tree

2 files changed

+108
-3
lines changed

2 files changed

+108
-3
lines changed

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,14 @@ class _WrappedTextOutput:
4444
value: str | None
4545

4646

47-
@dataclass
47+
@dataclass(init=False)
4848
class _WrappedToolOutput:
4949
"""A wrapper class to tag an output that came from the custom_output_args field."""
5050

51-
value: Any | None
51+
value: dict[str, Any] | None
52+
53+
def __init__(self, value: Any | None):
54+
self.value = pydantic_core.to_jsonable_python(value)
5255

5356

5457
@dataclass(init=False)
@@ -364,7 +367,7 @@ def __init__(self, schema: _utils.ObjectJsonSchema, seed: int = 0):
364367
self.defs = schema.get('$defs', {})
365368
self.seed = seed
366369

367-
def generate(self) -> Any:
370+
def generate(self) -> dict[str, Any]:
368371
"""Generate data for the JSON schema."""
369372
return self._gen_any(self.schema)
370373

tests/models/test_model_test.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,40 @@ def test_custom_output_args():
6969
agent = Agent(output_type=tuple[str, str])
7070
result = agent.run_sync('x', model=TestModel(custom_output_args=['a', 'b']))
7171
assert result.output == ('a', 'b')
72+
assert result.all_messages() == snapshot(
73+
[
74+
ModelRequest(
75+
parts=[
76+
UserPromptPart(
77+
content='x',
78+
timestamp=IsNow(tz=timezone.utc),
79+
)
80+
]
81+
),
82+
ModelResponse(
83+
parts=[
84+
ToolCallPart(
85+
tool_name='final_result',
86+
args={'response': ['a', 'b']},
87+
tool_call_id='pyd_ai_tool_call_id__final_result',
88+
)
89+
],
90+
usage=RequestUsage(input_tokens=51, output_tokens=7),
91+
model_name='test',
92+
timestamp=IsNow(tz=timezone.utc),
93+
),
94+
ModelRequest(
95+
parts=[
96+
ToolReturnPart(
97+
tool_name='final_result',
98+
content='Final result processed.',
99+
tool_call_id='pyd_ai_tool_call_id__final_result',
100+
timestamp=IsNow(tz=timezone.utc),
101+
)
102+
]
103+
),
104+
]
105+
)
72106

73107

74108
def test_custom_output_args_model():
@@ -79,12 +113,80 @@ class Foo(BaseModel):
79113
agent = Agent(output_type=Foo)
80114
result = agent.run_sync('x', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1}))
81115
assert result.output == Foo(foo='a', bar=1)
116+
assert result.all_messages() == snapshot(
117+
[
118+
ModelRequest(
119+
parts=[
120+
UserPromptPart(
121+
content='x',
122+
timestamp=IsNow(tz=timezone.utc),
123+
)
124+
]
125+
),
126+
ModelResponse(
127+
parts=[
128+
ToolCallPart(
129+
tool_name='final_result',
130+
args={'foo': 'a', 'bar': 1},
131+
tool_call_id='pyd_ai_tool_call_id__final_result',
132+
)
133+
],
134+
usage=RequestUsage(input_tokens=51, output_tokens=6),
135+
model_name='test',
136+
timestamp=IsNow(tz=timezone.utc),
137+
),
138+
ModelRequest(
139+
parts=[
140+
ToolReturnPart(
141+
tool_name='final_result',
142+
content='Final result processed.',
143+
tool_call_id='pyd_ai_tool_call_id__final_result',
144+
timestamp=IsNow(tz=timezone.utc),
145+
)
146+
]
147+
),
148+
]
149+
)
82150

83151

84152
def test_output_type():
85153
agent = Agent(output_type=tuple[str, str])
86154
result = agent.run_sync('x', model=TestModel())
87155
assert result.output == ('a', 'a')
156+
assert result.all_messages() == snapshot(
157+
[
158+
ModelRequest(
159+
parts=[
160+
UserPromptPart(
161+
content='x',
162+
timestamp=IsNow(tz=timezone.utc),
163+
)
164+
]
165+
),
166+
ModelResponse(
167+
parts=[
168+
ToolCallPart(
169+
tool_name='final_result',
170+
args={'response': ['a', 'a']},
171+
tool_call_id='pyd_ai_tool_call_id__final_result',
172+
)
173+
],
174+
usage=RequestUsage(input_tokens=51, output_tokens=7),
175+
model_name='test',
176+
timestamp=IsNow(tz=timezone.utc),
177+
),
178+
ModelRequest(
179+
parts=[
180+
ToolReturnPart(
181+
tool_name='final_result',
182+
content='Final result processed.',
183+
tool_call_id='pyd_ai_tool_call_id__final_result',
184+
timestamp=IsNow(tz=timezone.utc),
185+
)
186+
]
187+
),
188+
]
189+
)
88190

89191

90192
def test_tool_retry():

0 commit comments

Comments
 (0)