
Commit ccd26a1

Use isinstance checks for message kinds (#252)
Parent: d3b7f2d

4 files changed: +25 −23 lines changed

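The change is mechanical but meaningful: comparing a string `message_kind` attribute gives the type checker nothing to work with, while an `isinstance` check narrows the message type inside the guarded branch and turns a misspelled kind into an immediate `NameError` instead of a string that silently never matches. A minimal sketch of the difference, using simplified stand-in classes rather than the real `pydantic_ai.messages` definitions:

from dataclasses import dataclass


# Simplified stand-ins; the real pydantic_ai message classes carry more
# fields and serialization machinery.
@dataclass
class SystemPrompt:
    content: str


@dataclass
class UserPrompt:
    content: str


Message = SystemPrompt | UserPrompt


def first_user_content(messages: list[Message]) -> str | None:
    for m in messages:
        # isinstance narrows `m` to UserPrompt for mypy/pyright, so
        # `m.content` below is checked against UserPrompt specifically.
        if isinstance(m, UserPrompt):
            return m.content
    return None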

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 7 additions & 8 deletions
@@ -20,7 +20,6 @@
     models,
     result,
 )
-from .messages import TextPart, ToolCallPart
 from .result import ResultData
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
@@ -795,7 +794,7 @@ async def _prepare_messages(
         self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
     ) -> tuple[int, list[_messages.Message]]:
         # if message history includes system prompts, we don't want to regenerate them
-        if message_history and any(m.message_kind == 'system-prompt' for m in message_history):
+        if message_history and any(isinstance(m, _messages.SystemPrompt) for m in message_history):
             # shallow copy messages
             messages = message_history.copy()
         else:
@@ -816,9 +815,9 @@ async def _handle_model_response(
         A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
         """
         texts: list[str] = []
-        tool_calls: list[ToolCallPart] = []
+        tool_calls: list[_messages.ToolCallPart] = []
         for item in model_response.parts:
-            if isinstance(item, TextPart):
+            if isinstance(item, _messages.TextPart):
                 texts.append(item.content)
             else:
                 tool_calls.append(item)
@@ -852,7 +851,7 @@ async def _handle_text_response(
         return None, [response]
 
     async def _handle_structured_response(
-        self, tool_calls: list[ToolCallPart], deps: AgentDeps
+        self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'
@@ -870,7 +869,7 @@ async def _handle_structured_response(
 
     async def _process_final_tool_calls(
         self,
-        tool_calls: list[ToolCallPart],
+        tool_calls: list[_messages.ToolCallPart],
         deps: AgentDeps,
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
         """Process any final result tool calls and return the first valid result."""
@@ -916,7 +915,7 @@ async def _process_final_tool_calls(
 
     async def _process_function_tools(
         self,
-        tool_calls: list[ToolCallPart],
+        tool_calls: list[_messages.ToolCallPart],
         deps: AgentDeps,
     ) -> list[_messages.Message]:
         """Process function (non-final) tool calls in parallel."""
@@ -1013,7 +1012,7 @@ async def _handle_streamed_model_response(
         # we now run all tool functions in parallel
         tasks: list[asyncio.Task[_messages.Message]] = []
         for item in structured_msg.parts:
-            if isinstance(item, ToolCallPart):
+            if isinstance(item, _messages.ToolCallPart):
                 call = item
                 if tool := self._function_tools.get(call.tool_name):
                     tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
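The `texts`/`tool_calls` split in `_handle_model_response` above is where the narrowing pays off: after `isinstance(item, _messages.TextPart)`, the checker knows the `else` branch only sees tool calls. A rough self-contained sketch of that dispatch, with hypothetical stand-in part classes rather than the real ones:

from dataclasses import dataclass


@dataclass
class TextPart:
    content: str


@dataclass
class ToolCallPart:
    tool_name: str


def split_parts(parts: list[TextPart | ToolCallPart]) -> tuple[list[str], list[ToolCallPart]]:
    texts: list[str] = []
    tool_calls: list[ToolCallPart] = []
    for item in parts:
        if isinstance(item, TextPart):
            texts.append(item.content)  # `item` is narrowed to TextPart here
        else:
            tool_calls.append(item)  # ...and to ToolCallPart here
    return texts, tool_calls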

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 1 deletion
@@ -196,7 +196,7 @@ async def _messages_create(
         anthropic_messages: list[MessageParam] = []
 
         for m in messages:
-            if m.message_kind == 'system-prompt':
+            if isinstance(m, SystemPrompt):
                 system_prompt += m.content
             else:
                 anthropic_messages.append(self._map_message(m))
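The Anthropic mapping above uses the same check to partition history: system prompts are folded into a single system string, everything else becomes a chat message. A hedged sketch of that accumulation pattern, reusing the stand-in `Message`/`SystemPrompt` classes from the first sketch (not the real `MessageParam` mapping):

def partition_history(messages: list[Message]) -> tuple[str, list[Message]]:
    # Concatenate system prompt content; pass other messages through.
    system_prompt = ''
    chat_messages: list[Message] = []
    for m in messages:
        if isinstance(m, SystemPrompt):
            system_prompt += m.content
        else:
            chat_messages.append(m)
    return system_prompt, chat_messages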

tests/models/test_model_function.py

Lines changed: 5 additions & 5 deletions
@@ -82,7 +82,7 @@ async def weather_model(messages: list[Message], info: AgentInfo) -> ModelResponse:
     assert info.allow_text_result
     assert {t.name for t in info.function_tools} == {'get_location', 'get_weather'}
     last = messages[-1]
-    if last.message_kind == 'user-prompt':
+    if isinstance(last, UserPrompt):
         return ModelResponse(
             parts=[
                 ToolCallPart.from_json(
@@ -91,11 +91,11 @@ async def weather_model(messages: list[Message], info: AgentInfo) -> ModelResponse:
                 )
             ]
         )
-    elif last.message_kind == 'tool-return':
+    elif isinstance(last, ToolReturn):
         if last.tool_name == 'get_location':
             return ModelResponse(parts=[ToolCallPart.from_json('get_weather', last.model_response_str())])
         elif last.tool_name == 'get_weather':
-            location_name = next(m.content for m in messages if m.message_kind == 'user-prompt')
+            location_name = next(m.content for m in messages if isinstance(m, UserPrompt))
             return ModelResponse.from_text(f'{last.content} in {location_name}')
 
     raise ValueError(f'Unexpected message: {last}')
@@ -177,7 +177,7 @@ def test_weather(set_event_loop: None):
 
 async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelResponse:  # pragma: no cover
     last = messages[-1]
-    if last.message_kind == 'user-prompt':
+    if isinstance(last, UserPrompt):
         if last.content.startswith('{'):
             details = json.loads(last.content)
             return ModelResponse(
@@ -188,7 +188,7 @@ async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelResponse:
                 )
             ]
         )
-    elif last.message_kind == 'tool-return':
+    elif isinstance(last, ToolReturn):
         return ModelResponse.from_text(pydantic_core.to_json(last).decode())
 
     raise ValueError(f'Unexpected message: {last}')

tests/test_examples.py

Lines changed: 12 additions & 9 deletions
@@ -22,7 +22,10 @@
     ArgsDict,
     Message,
     ModelResponse,
+    RetryPrompt,
     ToolCallPart,
+    ToolReturn,
+    UserPrompt,
 )
 from pydantic_ai.models import KnownModelName, Model
 from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
@@ -215,7 +218,7 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
 
 async def model_logic(messages: list[Message], info: AgentInfo) -> ModelResponse:  # pragma: no cover
     m = messages[-1]
-    if m.message_kind == 'user-prompt':
+    if isinstance(m, UserPrompt):
         if response := text_responses.get(m.content):
             if isinstance(response, str):
                 return ModelResponse.from_text(content=response)
@@ -225,28 +228,28 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelResponse:
         if re.fullmatch(r'sql prompt \d+', m.content):
             return ModelResponse.from_text(content='SELECT 1')
 
-    elif m.message_kind == 'tool-return' and m.tool_name == 'roulette_wheel':
+    elif isinstance(m, ToolReturn) and m.tool_name == 'roulette_wheel':
         win = m.content == 'winner'
         return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict({'response': win}))])
-    elif m.message_kind == 'tool-return' and m.tool_name == 'roll_die':
+    elif isinstance(m, ToolReturn) and m.tool_name == 'roll_die':
         return ModelResponse(parts=[ToolCallPart(tool_name='get_player_name', args=ArgsDict({}))])
-    elif m.message_kind == 'tool-return' and m.tool_name == 'get_player_name':
+    elif isinstance(m, ToolReturn) and m.tool_name == 'get_player_name':
         return ModelResponse.from_text(content="Congratulations Anne, you guessed correctly! You're a winner!")
     if (
-        m.message_kind == 'retry-prompt'
+        isinstance(m, RetryPrompt)
         and isinstance(m.content, str)
         and m.content.startswith("No user found with name 'Joh")
     ):
         return ModelResponse(parts=[ToolCallPart(tool_name='get_user_by_name', args=ArgsDict({'name': 'John Doe'}))])
-    elif m.message_kind == 'tool-return' and m.tool_name == 'get_user_by_name':
+    elif isinstance(m, ToolReturn) and m.tool_name == 'get_user_by_name':
         args = {
             'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!',
             'user_id': 123,
         }
         return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args))])
-    elif m.message_kind == 'retry-prompt' and m.tool_name == 'calc_volume':
+    elif isinstance(m, RetryPrompt) and m.tool_name == 'calc_volume':
         return ModelResponse(parts=[ToolCallPart(tool_name='calc_volume', args=ArgsDict({'size': 6}))])
-    elif m.message_kind == 'tool-return' and m.tool_name == 'customer_balance':
+    elif isinstance(m, ToolReturn) and m.tool_name == 'customer_balance':
         args = {
             'support_advice': 'Hello John, your current account balance, including pending transactions, is $123.45.',
             'block_card': False,
@@ -262,7 +265,7 @@ async def stream_model_logic(
     messages: list[Message], info: AgentInfo
 ) -> AsyncIterator[str | DeltaToolCalls]:  # pragma: no cover
     m = messages[-1]
-    if m.message_kind == 'user-prompt':
+    if isinstance(m, UserPrompt):
         if response := text_responses.get(m.content):
             if isinstance(response, str):
                 words = response.split(' ')
