Skip to content

Cannot handle consecutive model request/model response #2860

@yf-yang

Description

@yf-yang

Initial Checks

Description

If the message history contains several consecutive `ModelRequest`s (tool returns) or several consecutive `ModelResponse`s (tool calls), the framework fails to handle them correctly.

The cause is here:

if (tool_call_results := ctx.deps.tool_call_results) is not None:
if messages and (last_message := messages[-1]) and isinstance(last_message, _messages.ModelRequest):
# If tool call results were provided, that means the previous run ended on deferred tool calls.
# That run would typically have ended on a `ModelResponse`, but if it had a mix of deferred tool calls and ones that could already be executed,
# a `ModelRequest` would already have been added to the history with the preliminary results, even if it wouldn't have been sent to the model yet.
# So now that we have all of the deferred results, we roll back to the last `ModelResponse` and store the contents of the `ModelRequest` on `deferred_tool_results` to be handled by `CallToolsNode`.
ctx.deps.tool_call_results = self._update_tool_call_results_from_model_request(
tool_call_results, last_message
)
messages.pop()
if not messages:
raise exceptions.UserError('Tool call results were provided, but the message history is empty.')
if messages and (last_message := messages[-1]):
if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
# Drop last message from history and reuse its parts
messages.pop()
parts.extend(last_message.parts)
elif isinstance(last_message, _messages.ModelResponse):
call_tools_node = await self._handle_message_history_model_response(ctx, last_message)
if call_tools_node is not None:
return call_tools_node

Both blocks use a single `if` rather than a loop, which indicates the code assumes there is at most one trailing request/response.

However, I don't know where the proper place to fix it is. Personally, I think there should be a history-merge step somewhere that merges consecutive requests/responses.

Example Code

import asyncio

from dotenv import load_dotenv
from pydantic_ai import (
  Agent,
  CallDeferred,
  DeferredToolRequests,
  DeferredToolResults,
)
from pydantic_ai.messages import (
  ModelRequest,
  ModelResponse,
  ToolCallPart,
  ToolReturnPart,
  UserPromptPart,
)

# Presumably loads environment variables (e.g. the provider API key) from a
# local .env file before the agent is created -- confirm against python-dotenv.
load_dotenv()

# output_type lists both `str` and `DeferredToolRequests`, so a run can
# presumably end either with plain text or with a set of deferred tool requests.
agent = Agent("anthropic:claude-sonnet-4-0", output_type=[str, DeferredToolRequests])


# Tool registered on `agent` that adds two integers. The message histories
# built in main() refer to it as tool_name="sum".
# NOTE(review): the function name shadows the built-in sum() for the rest of
# this module; renaming it would presumably change the registered tool name,
# so it is left as-is here -- confirm before renaming.
@agent.tool_plain
async def sum(a: int, b: int) -> int:
  return a + b


# Tool registered on `agent` that never computes a product itself: it always
# raises CallDeferred, so `a` and `b` are unused and no int is ever returned.
# In main(), the results for the two multiply calls are supplied separately
# through DeferredToolResults (keys "multiply-1" and "multiply-2").
@agent.tool_plain
async def multiply(a: int, b: int) -> int:
  raise CallDeferred


async def main(case: str = "failed-2"):
  """Resume an agent run from a hand-built message history and print the result.

  The original script assigned `messages` twice in a row, so the history
  labelled "Failed case 1" was always overwritten by "Failed case 2" and could
  never be reproduced without editing the file. The known-good history only
  existed as commented-out code. All three histories are now selectable.

  Args:
    case: Which history to replay.
      "success"  - one ModelResponse with all four tool calls, followed by one
                   ModelRequest carrying both `sum` tool returns.
      "failed-1" - same ModelResponse, but the two `sum` tool returns are
                   split across two consecutive ModelRequests.
      "failed-2" - the four tool calls are split across four consecutive
                   ModelResponses, followed by one ModelRequest with both
                   `sum` tool returns. This is the default because it is the
                   history the original script actually ran.

  Raises:
    ValueError: If `case` is not one of the names above.
  """
  user_prompt = "Calculate 1+2, 3+4, 5*6, 7*8"

  # Tool calls issued by the model; reused by every history below.
  sum_1_call = ToolCallPart(tool_name="sum", args={"a": 1, "b": 2}, tool_call_id="sum-1")
  sum_2_call = ToolCallPart(tool_name="sum", args={"a": 3, "b": 4}, tool_call_id="sum-2")
  multiply_1_call = ToolCallPart(tool_name="multiply", args={"a": 5, "b": 6}, tool_call_id="multiply-1")
  multiply_2_call = ToolCallPart(tool_name="multiply", args={"a": 7, "b": 8}, tool_call_id="multiply-2")

  # Results of the two `sum` calls; the `multiply` calls have no results in
  # the history and are supplied through `tool_results` further down.
  sum_1_return = ToolReturnPart(tool_name="sum", tool_call_id="sum-1", content="3")
  sum_2_return = ToolReturnPart(tool_name="sum", tool_call_id="sum-2", content="7")

  histories = {
    "success": [
      ModelRequest(parts=[UserPromptPart(content=user_prompt)]),
      ModelResponse(parts=[sum_1_call, sum_2_call, multiply_1_call, multiply_2_call]),
      ModelRequest(parts=[sum_1_return, sum_2_return]),
    ],
    "failed-1": [
      ModelRequest(parts=[UserPromptPart(content=user_prompt)]),
      ModelResponse(parts=[sum_1_call, sum_2_call, multiply_1_call, multiply_2_call]),
      ModelRequest(parts=[sum_1_return]),
      ModelRequest(parts=[sum_2_return]),
    ],
    "failed-2": [
      ModelRequest(parts=[UserPromptPart(content=user_prompt)]),
      ModelResponse(parts=[sum_1_call]),
      ModelResponse(parts=[sum_2_call]),
      ModelResponse(parts=[multiply_1_call]),
      ModelResponse(parts=[multiply_2_call]),
      ModelRequest(parts=[sum_1_return, sum_2_return]),
    ],
  }

  if case not in histories:
    raise ValueError(f"Unknown case {case!r}; expected one of {sorted(histories)}")
  messages = histories[case]

  tool_results = DeferredToolResults(calls={"multiply-1": "30", "multiply-2": "56"})

  result = await agent.run(message_history=messages, deferred_tool_results=tool_results)
  print(result.output)
  print(result.all_messages())


# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
  asyncio.run(main())

Python, Pydantic AI & LLM client version

1.0.3

Metadata

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions