Commit 866a031: Prioritize tool calls over eager text responses (#505)
Co-authored-by: David Montague <[email protected]>
Parent: 8715834

3 files changed (+61, −19 lines)

pydantic_ai_examples/weather_agent.py
5 additions, 1 deletion

@@ -35,7 +35,11 @@ class Deps:
 
 weather_agent = Agent(
     'openai:gpt-4o',
-    system_prompt='Be concise, reply with one sentence.',
+    system_prompt=(
+        'Be concise, reply with one sentence.'
+        'Use the `get_lat_lng` tool to get the latitude and longitude of the locations, '
+        'then use the `get_weather` tool to get the weather.'
+    ),
     deps_type=Deps,
     retries=2,
 )
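
A side note on the new prompt: the parenthesized value relies on Python's implicit concatenation of adjacent string literals, and as written the first two literals join without a separating space. A minimal illustration (the variable name here is ours, not part of the commit):

prompt = (
    'Be concise, reply with one sentence.'
    'Use the `get_lat_lng` tool to get the latitude and longitude of the locations, '
    'then use the `get_weather` tool to get the weather.'
)
assert 'sentence.Use' in prompt  # the literals fuse directly, no space inserted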

pydantic_ai_slim/pydantic_ai/agent.py
7 additions, 3 deletions

@@ -864,11 +864,15 @@ async def _handle_model_response(
             else:
                 tool_calls.append(part)
 
-        if texts:
+        # At the moment, we prioritize at least executing tool calls if they are present.
+        # In the future, we'd consider making this configurable at the agent or run level.
+        # This accounts for cases like anthropic returns that might contain a text response
+        # and a tool call response, where the text response just indicates the tool call will happen.
+        if tool_calls:
+            return await self._handle_structured_response(tool_calls, run_context)
+        elif texts:
             text = '\n\n'.join(texts)
             return await self._handle_text_response(text, run_context)
-        elif tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')
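
This is the behavioral core of the commit: when a single model response carries both plain text and tool calls, the run now executes the tool calls instead of stopping at the text. A minimal sketch of the new dispatch order (illustrative only; the function name and return values below are ours, not the library's API):

def dispatch(texts: list[str], tool_calls: list[str]) -> str:
    # Tool calls are checked first, mirroring the reordered branches above.
    if tool_calls:
        return 'run_tools'
    elif texts:
        # Text alone still ends the run as the final answer.
        return 'final_text'
    else:
        raise RuntimeError('Received empty model response')

assert dispatch(texts=["I'll look that up."], tool_calls=['get_weather']) == 'run_tools'
assert dispatch(texts=['It is sunny in London.'], tool_calls=[]) == 'final_text'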

tests/models/test_gemini.py
49 additions, 15 deletions

@@ -19,6 +19,7 @@
     ModelResponse,
     RetryPromptPart,
     SystemPromptPart,
+    TextPart,
     ToolCallPart,
     ToolReturnPart,
     UserPromptPart,

@@ -537,25 +538,58 @@ def handler(_: httpx.Request):
 
 
 async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient):
-    response = gemini_response(
-        _GeminiContent(
-            role='model',
-            parts=[
-                _GeminiTextPart(text='foo'),
-                _function_call_part_from_call(
+    """Indicates that tool calls are prioritized over text in heterogeneous responses."""
+    responses = [
+        gemini_response(
+            _content_model_response(
+                ModelResponse(
+                    parts=[TextPart(content='foo'), ToolCallPart.from_raw_args('get_location', {'loc_name': 'London'})]
+                )
+            )
+        ),
+        gemini_response(_content_model_response(ModelResponse.from_text('final response'))),
+    ]
+
+    gemini_client = get_gemini_client(responses)
+    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
+    agent = Agent(m)
+
+    @agent.tool_plain
+    async def get_location(loc_name: str) -> str:
+        if loc_name == 'London':
+            return json.dumps({'lat': 51, 'lng': 0})
+        else:
+            raise ModelRetry('Wrong location, please try again')
+
+    result = await agent.run('Hello')
+    assert result.data == 'final response'
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+                ]
+            ),
+            ModelResponse(
+                parts=[
+                    TextPart(content='foo'),
                     ToolCallPart(
                         tool_name='get_location',
-                        args=ArgsDict(args_dict={'loc_name': 'San Fransisco'}),
+                        args=ArgsDict(args_dict={'loc_name': 'London'}),
+                    ),
+                ],
+                timestamp=IsNow(tz=timezone.utc),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
                    )
-                ),
-            ],
-        )
+                ]
+            ),
+            ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
+        ]
     )
-    gemini_client = get_gemini_client(response)
-    m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
-    agent = Agent(m)
-    result = await agent.run('Hello')
-    assert result.data == 'foo'
 
 
 async def test_stream_text(get_gemini_client: GetGeminiClient):
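
Reading the updated test end to end: the first mocked response mixes a text part ('foo') with a get_location tool call, so the agent runs the tool and sends the return value back; the second mocked response is plain text, which becomes result.data. A toy simulation of that two-turn flow (the names and structure here are ours for illustration, not the pydantic_ai API):

def run_toy_agent(model_turns, tools):
    # Each turn is (texts, tool_calls); tool calls take priority over text.
    history = [('user', 'Hello')]
    for texts, tool_calls in model_turns:
        history.append(('model', texts, tool_calls))
        if tool_calls:
            for name, args in tool_calls:
                history.append(('tool-return', tools[name](**args)))
        elif texts:
            return '\n\n'.join(texts), history
    raise RuntimeError('model never produced a final text response')

turns = [
    (['foo'], [('get_location', {'loc_name': 'London'})]),  # heterogeneous turn
    (['final response'], []),                                # plain-text turn
]
result, _ = run_toy_agent(turns, {'get_location': lambda loc_name: '{"lat": 51, "lng": 0}'})
assert result == 'final response'  # 'foo' was bypassed in favor of the tool call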
