Skip to content

Commit 4e60e9d

Browse files
authored
Fix duplicate output tool return part when concatenating first run messages with follow-up new_messages (#3075)
1 parent 5e596f1 commit 4e60e9d

File tree

2 files changed

+136
-4
lines changed

2 files changed

+136
-4
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -458,15 +458,13 @@ async def _prepare_request(
458458

459459
original_history = ctx.state.message_history[:]
460460
message_history = await _process_message_history(original_history, ctx.deps.history_processors, run_context)
461-
# Never merge the new `ModelRequest` with the one preceding it, to keep `new_messages()` from accidentally including part of the existing message history
462-
message_history = [*_clean_message_history(message_history[:-1]), message_history[-1]]
463461
# `ctx.state.message_history` is the same list used by `capture_run_messages`, so we should replace its contents, not the reference
464462
ctx.state.message_history[:] = message_history
465463
# Update the new message index to ensure `result.new_messages()` returns the correct messages
466464
ctx.deps.new_message_index -= len(original_history) - len(message_history)
467465

468-
# Do one more cleaning pass to merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts,
469-
# but don't store it in the message history on state.
466+
# Merge possible consecutive trailing `ModelRequest`s into one, with tool call parts before user parts,
467+
# but don't store it in the message history on state. This is just for the benefit of model classes that want clear user/assistant boundaries.
470468
# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary
471469
message_history = _clean_message_history(message_history)
472470

tests/test_agent.py

Lines changed: 134 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5377,3 +5377,137 @@ def dynamic_instr() -> str:
53775377
sys_texts = [p.content for p in req.parts if isinstance(p, SystemPromptPart)]
53785378
# The dynamic system prompt should still be present since overrides target instructions only
53795379
assert dynamic_value in sys_texts
5380+
5381+
5382+
def test_continue_conversation_that_ended_in_output_tool_call(allow_model_requests: None):
    """Regression test for #3075: continuing a conversation whose history ends in an
    output-tool (`final_result`) return must not duplicate that tool-return part into
    the first message of the follow-up run's `new_messages()`.
    """

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Once the dice has been rolled, produce the structured output via the
        # `final_result` output tool; otherwise ask to call `roll_dice` first.
        if any(isinstance(p, ToolReturnPart) and p.tool_name == 'roll_dice' for p in messages[-1].parts):
            return ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'dice_roll': 4},
                        tool_call_id='pyd_ai_tool_call_id__final_result',
                    )
                ]
            )
        return ModelResponse(
            parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')]
        )

    class Result(BaseModel):
        dice_roll: int

    agent = Agent(FunctionModel(llm), output_type=Result)

    @agent.tool_plain
    def roll_dice() -> int:
        return 4

    # First run: ends with a `final_result` tool call and its tool-return request.
    result = agent.run_sync('Roll me a dice.')
    messages = result.all_messages()
    assert messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(
                        content='Roll me a dice.',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')],
                usage=RequestUsage(input_tokens=55, output_tokens=2),
                model_name='function:llm:',
                timestamp=IsDatetime(),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='roll_dice',
                        content=4,
                        tool_call_id='pyd_ai_tool_call_id__roll_dice',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'dice_roll': 4},
                        tool_call_id='pyd_ai_tool_call_id__final_result',
                    )
                ],
                usage=RequestUsage(input_tokens=56, output_tokens=6),
                model_name='function:llm:',
                timestamp=IsDatetime(),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id='pyd_ai_tool_call_id__final_result',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
        ]
    )

    # Second run continues from the first run's history; `new_messages()` must start
    # with ONLY the new user prompt — not a merged request that also carries the
    # trailing `final_result` tool return from the previous run.
    result = agent.run_sync('Roll me a dice again.', message_history=messages)
    new_messages = result.new_messages()
    assert new_messages == snapshot(
        [
            ModelRequest(
                parts=[
                    UserPromptPart(
                        content='Roll me a dice again.',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
            ModelResponse(
                parts=[ToolCallPart(tool_name='roll_dice', args={}, tool_call_id='pyd_ai_tool_call_id__roll_dice')],
                usage=RequestUsage(input_tokens=66, output_tokens=8),
                model_name='function:llm:',
                timestamp=IsDatetime(),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='roll_dice',
                        content=4,
                        tool_call_id='pyd_ai_tool_call_id__roll_dice',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name='final_result',
                        args={'dice_roll': 4},
                        tool_call_id='pyd_ai_tool_call_id__final_result',
                    )
                ],
                usage=RequestUsage(input_tokens=67, output_tokens=12),
                model_name='function:llm:',
                timestamp=IsDatetime(),
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name='final_result',
                        content='Final result processed.',
                        tool_call_id='pyd_ai_tool_call_id__final_result',
                        timestamp=IsDatetime(),
                    )
                ]
            ),
        ]
    )

    # The explicit duplicate-prevention check from the bug report.
    assert not any(isinstance(p, ToolReturnPart) and p.tool_name == 'final_result' for p in new_messages[0].parts)

0 commit comments

Comments
 (0)