Skip to content

Commit 29e986e

Browse files
Adding commentary and tests re heterogeneous behavior (#517)
1 parent a193111 commit 29e986e

File tree

3 files changed

+59
-56
lines changed

3 files changed

+59
-56
lines changed

pydantic_ai_examples/weather_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class Deps:
3535

3636
weather_agent = Agent(
3737
'openai:gpt-4o',
38+
# 'Be concise, reply with one sentence.' is enough for some models (like openai) to use
39+
# the below tools appropriately, but others like anthropic and gemini require a bit more direction.
3840
system_prompt=(
3941
'Be concise, reply with one sentence.'
4042
'Use the `get_lat_lng` tool to get the latitude and longitude of the locations, '

tests/models/test_gemini.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
ModelResponse,
2020
RetryPromptPart,
2121
SystemPromptPart,
22-
TextPart,
2322
ToolCallPart,
2423
ToolReturnPart,
2524
UserPromptPart,
@@ -537,61 +536,6 @@ def handler(_: httpx.Request):
537536
assert str(exc_info.value) == snapshot('Unexpected response from gemini 401, body:\ninvalid request')
538537

539538

540-
async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient):
541-
"""Indicates that tool calls are prioritized over text in heterogeneous responses."""
542-
responses = [
543-
gemini_response(
544-
_content_model_response(
545-
ModelResponse(
546-
parts=[TextPart(content='foo'), ToolCallPart.from_raw_args('get_location', {'loc_name': 'London'})]
547-
)
548-
)
549-
),
550-
gemini_response(_content_model_response(ModelResponse.from_text('final response'))),
551-
]
552-
553-
gemini_client = get_gemini_client(responses)
554-
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client)
555-
agent = Agent(m)
556-
557-
@agent.tool_plain
558-
async def get_location(loc_name: str) -> str:
559-
if loc_name == 'London':
560-
return json.dumps({'lat': 51, 'lng': 0})
561-
else:
562-
raise ModelRetry('Wrong location, please try again')
563-
564-
result = await agent.run('Hello')
565-
assert result.data == 'final response'
566-
assert result.all_messages() == snapshot(
567-
[
568-
ModelRequest(
569-
parts=[
570-
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
571-
]
572-
),
573-
ModelResponse(
574-
parts=[
575-
TextPart(content='foo'),
576-
ToolCallPart(
577-
tool_name='get_location',
578-
args=ArgsDict(args_dict={'loc_name': 'London'}),
579-
),
580-
],
581-
timestamp=IsNow(tz=timezone.utc),
582-
),
583-
ModelRequest(
584-
parts=[
585-
ToolReturnPart(
586-
tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
587-
)
588-
]
589-
),
590-
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
591-
]
592-
)
593-
594-
595539
async def test_stream_text(get_gemini_client: GetGeminiClient):
596540
responses = [
597541
gemini_response(_content_model_response(ModelResponse.from_text('Hello '))),

tests/test_agent.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import sys
23
from datetime import timezone
34
from typing import Any, Callable, Union
@@ -16,6 +17,7 @@
1617
ModelMessage,
1718
ModelRequest,
1819
ModelResponse,
20+
ModelResponsePart,
1921
RetryPromptPart,
2022
SystemPromptPart,
2123
TextPart,
@@ -1124,3 +1126,58 @@ def return_empty_text(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
11241126

11251127
result = await agent.run('Hello')
11261128
assert result.data == ('foo', 'bar')
1129+
1130+
1131+
def test_heterogenous_reponses_non_streaming(set_event_loop: None) -> None:
1132+
"""Indicates that tool calls are prioritized over text in heterogeneous responses."""
1133+
1134+
def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1135+
assert info.result_tools is not None
1136+
parts: list[ModelResponsePart] = []
1137+
if len(messages) == 1:
1138+
parts = [
1139+
TextPart(content='foo'),
1140+
ToolCallPart.from_raw_args('get_location', {'loc_name': 'London'}),
1141+
]
1142+
else:
1143+
parts = [TextPart(content='final response')]
1144+
return ModelResponse(parts=parts)
1145+
1146+
agent = Agent(FunctionModel(return_model))
1147+
1148+
@agent.tool_plain
1149+
async def get_location(loc_name: str) -> str:
1150+
if loc_name == 'London':
1151+
return json.dumps({'lat': 51, 'lng': 0})
1152+
else:
1153+
raise ModelRetry('Wrong location, please try again')
1154+
1155+
result = agent.run_sync('Hello')
1156+
assert result.data == 'final response'
1157+
assert result.all_messages() == snapshot(
1158+
[
1159+
ModelRequest(
1160+
parts=[
1161+
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
1162+
]
1163+
),
1164+
ModelResponse(
1165+
parts=[
1166+
TextPart(content='foo'),
1167+
ToolCallPart(
1168+
tool_name='get_location',
1169+
args=ArgsDict(args_dict={'loc_name': 'London'}),
1170+
),
1171+
],
1172+
timestamp=IsNow(tz=timezone.utc),
1173+
),
1174+
ModelRequest(
1175+
parts=[
1176+
ToolReturnPart(
1177+
tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
1178+
)
1179+
]
1180+
),
1181+
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
1182+
]
1183+
)

0 commit comments

Comments
 (0)