
Commit ff7015a

jlowin and samuelcolvin authored
Generate tool results when using structured result (#179)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent 024b1e2 commit ff7015a
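
In short: `_handle_model_response` and `_handle_streamed_model_response` previously returned either a `_MarkFinalResult` marker or a list of messages; they now return a `(final_result, messages)` tuple, and a `ToolReturn` with content 'Final result processed.' is generated for the result-schema tool call so that every tool call in the history has a matching tool result. A minimal sketch of the new convention, using stand-in types rather than the real pydantic_ai classes:

from dataclasses import dataclass
from typing import Generic, TypeVar

ResultData = TypeVar('ResultData')


@dataclass
class MarkFinalResult(Generic[ResultData]):
    # Marker wrapper: a final result whose data is None is still detectable
    # with an `is not None` check on the wrapper itself.
    data: ResultData


@dataclass
class Message:
    role: str
    content: str


def handle_model_response(response: Message) -> tuple[MarkFinalResult[str] | None, list[Message]]:
    # New convention: always return the messages to append, plus an optional
    # final-result marker, instead of Either-style "marker or messages".
    if response.role == 'model-structured-response':
        # Generate a tool result for the result-schema call so the history stays valid.
        tool_return = Message('tool-return', 'Final result processed.')
        return MarkFinalResult(response.content), [tool_return]
    return None, [Message('retry-prompt', 'Plain text responses are not permitted.')]


messages: list[Message] = []
final_result, response_messages = handle_model_response(Message('model-structured-response', '{"a": 1}'))
messages.extend(response_messages)  # messages are now appended in every branch
if final_result is not None:
    print('final result:', final_result.data)

Appending the messages unconditionally means the `final_result` tool call always gets a matching result message in the history, which is what the test snapshots below assert.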

File tree: 6 files changed, +246 -48 lines


pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 54 additions & 35 deletions
@@ -215,23 +215,24 @@ async def run(
             cost += request_cost

             with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                either = await self._handle_model_response(model_response, deps)
+                final_result, response_messages = await self._handle_model_response(model_response, deps)

-                if isinstance(either, _MarkFinalResult):
-                    # we have a final result, end the conversation
-                    result_data = either.data
+                # Add all messages to the conversation
+                messages.extend(response_messages)
+
+                # Check if we got a final result
+                if final_result is not None:
+                    result_data = final_result.data
                     run_span.set_attribute('all_messages', messages)
                     run_span.set_attribute('cost', cost)
                     handle_span.set_attribute('result', result_data)
                     handle_span.message = 'handle model response -> final result'
                     return result.RunResult(messages, new_message_index, result_data, cost)
                 else:
                     # continue the conversation
-                    tool_responses = either
-                    handle_span.set_attribute('tool_responses', tool_responses)
-                    response_msgs = ' '.join(m.role for m in tool_responses)
+                    handle_span.set_attribute('tool_responses', response_messages)
+                    response_msgs = ' '.join(r.role for r in response_messages)
                     handle_span.message = f'handle model response -> {response_msgs}'
-                    messages.extend(tool_responses)

     def run_sync(
         self,
@@ -324,10 +325,16 @@ async def run_stream(
                 model_req_span.__exit__(None, None, None)

             with _logfire.span('handle model response') as handle_span:
-                either = await self._handle_streamed_model_response(model_response, deps)
+                final_result, response_messages = await self._handle_streamed_model_response(
+                    model_response, deps
+                )
+
+                # Add all messages to the conversation
+                messages.extend(response_messages)

-                if isinstance(either, _MarkFinalResult):
-                    result_stream = either.data
+                # Check if we got a final result
+                if final_result is not None:
+                    result_stream = final_result.data
                     run_span.set_attribute('all_messages', messages)
                     handle_span.set_attribute('result_type', result_stream.__class__.__name__)
                     handle_span.message = 'handle model response -> final result'
@@ -343,11 +350,10 @@ async def run_stream(
                     )
                     return
                 else:
-                    tool_responses = either
-                    handle_span.set_attribute('tool_responses', tool_responses)
-                    response_msgs = ' '.join(m.role for m in tool_responses)
+                    # continue the conversation
+                    handle_span.set_attribute('tool_responses', response_messages)
+                    response_msgs = ' '.join(r.role for r in response_messages)
                     handle_span.message = f'handle model response -> {response_msgs}'
-                    messages.extend(tool_responses)
             # the model_response should have been fully streamed by now, we can add it's cost
             cost += model_response.cost()

@@ -725,11 +731,11 @@ async def _prepare_messages(

     async def _handle_model_response(
         self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[ResultData] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
         """Process a non-streamed response from the model.

         Returns:
-            Return `Either` — left: final result data, right: list of messages to send back to the model.
+            A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
         """
         if model_response.role == 'model-text-response':
             # plain string response
@@ -739,15 +745,15 @@ async def _handle_model_response(
                     result_data = await self._validate_result(result_data_input, deps, None)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return [e.tool_retry]
+                    return None, [e.tool_retry]
                 else:
-                    return _MarkFinalResult(result_data)
+                    return _MarkFinalResult(result_data), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
-                return [response]
+                return None, [response]
         elif model_response.role == 'model-structured-response':
             if self._result_schema is not None:
                 # if there's a result schema, and any of the calls match one of its tools, return the result
@@ -759,9 +765,15 @@ async def _handle_model_response(
                         result_data = await self._validate_result(result_data, deps, call)
                     except _result.ToolRetryError as e:
                         self._incr_result_retry()
-                        return [e.tool_retry]
+                        return None, [e.tool_retry]
                     else:
-                        return _MarkFinalResult(result_data)
+                        # Add a ToolReturn message for the schema tool call
+                        tool_return = _messages.ToolReturn(
+                            tool_name=call.tool_name,
+                            content='Final result processed.',
+                            tool_id=call.tool_id,
+                        )
+                        return _MarkFinalResult(result_data), [tool_return]

             if not model_response.calls:
                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
@@ -776,26 +788,24 @@ async def _handle_model_response(
                     messages.append(self._unknown_tool(call.tool_name))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                messages += await asyncio.gather(*tasks)
-            return messages
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages
         else:
             assert_never(model_response)

     async def _handle_streamed_model_response(
         self, model_response: models.EitherStreamedResponse, deps: AgentDeps
-    ) -> _MarkFinalResult[models.EitherStreamedResponse] | list[_messages.Message]:
+    ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
         """Process a streamed response from the model.

-        TODO: change the response type to `models.EitherStreamedResponse | list[_messages.Message]` once we drop 3.9
-        (with 3.9 we get `TypeError: Subscripted generics cannot be used with class and instance checks`)
-
         Returns:
-            Return `Either` — left: final result data, right: list of messages to send back to the model.
+            A tuple of (final_result, messages). If final_result is not None, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response)
+                return _MarkFinalResult(model_response), []
             else:
                 self._incr_result_retry()
                 response = _messages.RetryPrompt(
@@ -805,7 +815,7 @@ async def _handle_streamed_model_response(
                 async for _ in model_response:
                     pass

-                return [response]
+                return None, [response]
         else:
             assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
             if self._result_schema is not None:
@@ -819,8 +829,14 @@ async def _handle_streamed_model_response(
                         break
                     structured_msg = model_response.get()

-                if self._result_schema.find_tool(structured_msg):
-                    return _MarkFinalResult(model_response)
+                if match := self._result_schema.find_tool(structured_msg):
+                    call, _ = match
+                    tool_return = _messages.ToolReturn(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_id=call.tool_id,
+                    )
+                    return _MarkFinalResult(model_response), [tool_return]

             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
@@ -839,8 +855,9 @@ async def _handle_streamed_model_response(
                 messages.append(self._unknown_tool(call.tool_name))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                messages += await asyncio.gather(*tasks)
-            return messages
+                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
+                messages.extend(task_results)
+            return None, messages

     async def _validate_result(
         self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
@@ -912,6 +929,8 @@ class _MarkFinalResult(Generic[ResultData]):
     """Marker class to indicate that the result is the final result.

     This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultData` directly.
+
+    It also avoids problems in the case where the result type is itself `None`, but is set.
     """

     data: ResultData
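
The new docstring line about a result type that is itself `None` explains why the tuple carries a marker object rather than the bare result data: with an optional result type, returning the data directly would make "final result of `None`" indistinguishable from "no final result". A small illustration of the pitfall, with a stand-in class rather than the real `_MarkFinalResult`:

from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar('T')


@dataclass
class MarkFinalResult(Generic[T]):
    data: T


def finish_with(data: T) -> tuple[MarkFinalResult[T] | None, list[str]]:
    # A result was produced, even if its data happens to be None.
    return MarkFinalResult(data), []


final, _ = finish_with(None)
assert final is not None  # the marker distinguishes "result of None" from "no result"
assert final.data is None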

tests/models/test_gemini.py

Lines changed: 10 additions & 0 deletions
@@ -435,6 +435,11 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
                 ],
                 timestamp=IsNow(tz=timezone.utc),
             ),
+            ToolReturn(
+                tool_name='final_result',
+                content='Final result processed.',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
         ]
     )

@@ -648,6 +653,11 @@ async def bar(y: str) -> str:
             ),
             ToolReturn(tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc)),
             ToolReturn(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)),
+            ToolReturn(
+                tool_name='final_result',
+                content='Final result processed.',
+                timestamp=IsNow(tz=timezone.utc),
+            ),
             ModelStructuredResponse(
                 calls=[
                     ToolCall(

tests/models/test_groq.py

Lines changed: 14 additions & 3 deletions
@@ -192,6 +192,12 @@ async def test_request_structured_response(allow_model_requests: None):
             ],
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
+        ToolReturn(
+            tool_name='final_result',
+            content='Final result processed.',
+            tool_id='123',
+            timestamp=IsNow(tz=timezone.utc),
+        ),
     ]
 )

@@ -277,15 +283,15 @@ async def get_location(loc_name: str) -> str:
                     tool_id='2',
                 )
             ],
-            timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+            timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ToolReturn(
             tool_name='get_location',
             content='{"lat": 51, "lng": 0}',
             tool_id='2',
             timestamp=IsNow(tz=timezone.utc),
         ),
-        ModelTextResponse(content='final response', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
+        ModelTextResponse(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
     ]
 )

@@ -390,10 +396,15 @@ async def test_stream_structured(allow_model_requests: None):
     assert result.is_complete

     assert result.cost() == snapshot(Cost())
-    assert result.timestamp() == snapshot(datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc))
     assert result.all_messages() == snapshot(
         [
             UserPrompt(content='', timestamp=IsNow(tz=timezone.utc)),
+            ToolReturn(
+                tool_name='final_result',
+                content='Final result processed.',
+                tool_id=None,
+                timestamp=IsNow(tz=timezone.utc),
+            ),
             ModelStructuredResponse(
                 calls=[
                     ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"first": "One", "second": "Two"}'))

tests/models/test_openai.py

Lines changed: 9 additions & 3 deletions
@@ -193,6 +193,12 @@ async def test_request_structured_response(allow_model_requests: None):
             ],
             timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
+        ToolReturn(
+            tool_name='final_result',
+            content='Final result processed.',
+            tool_id='123',
+            timestamp=IsNow(tz=timezone.utc),
+        ),
     ]
 )

@@ -264,7 +270,7 @@ async def get_location(loc_name: str) -> str:
                     tool_id='1',
                 )
             ],
-            timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+            timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         RetryPrompt(
             tool_name='get_location',
@@ -280,15 +286,15 @@ async def get_location(loc_name: str) -> str:
                     tool_id='2',
                 )
             ],
-            timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+            timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
         ),
         ToolReturn(
             tool_name='get_location',
             content='{"lat": 51, "lng": 0}',
             tool_id='2',
             timestamp=IsNow(tz=timezone.utc),
         ),
-        ModelTextResponse(content='final response', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
+        ModelTextResponse(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
     ]
 )
 assert result.cost() == snapshot(
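
All three test-snapshot updates add the same `ToolReturn(tool_name='final_result', content='Final result processed.', ...)` entry, presumably because chat APIs expect every tool call to be paired with a tool result when the history is sent back. A hedged usage sketch; the `Agent`, `TestModel`, and `result_type` names reflect the pydantic_ai API of this era and may differ:

# Hedged sketch: after this commit, a run that ends via a structured result
# should include a 'tool-return' message for the final_result tool call.
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel  # test double; import path assumed


class Answer(BaseModel):
    first: str
    second: str


agent = Agent(TestModel(), result_type=Answer)
result = agent.run_sync('Give me two words.')

roles = [m.role for m in result.all_messages()]
# Previously the final_result tool call had no matching tool result, which can
# make the message history invalid when passed back to a model.
assert 'tool-return' in roles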
