
Commit 851df07

Let message history end on ModelResponse and execute pending tool calls (#2562)
1 parent 5d870ce commit 851df07

File tree

2 files changed: +292 −53 lines


pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 43 additions & 50 deletions
@@ -23,7 +23,7 @@
 from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
-from .settings import ModelSettings, merge_model_settings
+from .settings import ModelSettings
 from .tools import RunContext, ToolDefinition, ToolKind
 
 if TYPE_CHECKING:
@@ -158,28 +158,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-    async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> _messages.ModelRequest:
-        run_context = build_run_context(ctx)
-        history, next_message = await self._prepare_messages(
-            self.user_prompt, ctx.state.message_history, ctx.deps.get_instructions, run_context
-        )
-        ctx.state.message_history = history
-        run_context.messages = history
-
-        return next_message
-
-    async def _prepare_messages(
-        self,
-        user_prompt: str | Sequence[_messages.UserContent] | None,
-        message_history: list[_messages.ModelMessage] | None,
-        get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
-        run_context: RunContext[DepsT],
-    ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
+    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], CallToolsNode[DepsT, NodeRunEndT]]:  # noqa UP007
         try:
             ctx_messages = get_captured_run_messages()
         except LookupError:
@@ -191,29 +170,48 @@ async def _prepare_messages
             messages = ctx_messages.messages
             ctx_messages.used = True
 
+        # Add message history to the `capture_run_messages` list, which will be empty at this point
+        messages.extend(ctx.state.message_history)
+        # Use the `capture_run_messages` list as the message history so that new messages are added to it
+        ctx.state.message_history = messages
+
+        run_context = build_run_context(ctx)
+
         parts: list[_messages.ModelRequestPart] = []
-        instructions = await get_instructions(run_context)
-        if message_history:
-            # Shallow copy messages
-            messages.extend(message_history)
+        if messages:
             # Reevaluate any dynamic system prompt parts
             await self._reevaluate_dynamic_prompts(messages, run_context)
         else:
             parts.extend(await self._sys_parts(run_context))
 
-        if user_prompt is not None:
-            parts.append(_messages.UserPromptPart(user_prompt))
-        elif (
-            len(parts) == 0
-            and message_history
-            and (last_message := message_history[-1])
-            and isinstance(last_message, _messages.ModelRequest)
-        ):
-            # Drop last message that came from history and reuse its parts
-            messages.pop()
-            parts.extend(last_message.parts)
+        if messages and (last_message := messages[-1]):
+            if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
+                # Drop last message from history and reuse its parts
+                messages.pop()
+                parts.extend(last_message.parts)
+            elif isinstance(last_message, _messages.ModelResponse):
+                if self.user_prompt is None:
+                    # `CallToolsNode` requires the tool manager to be prepared for the run step
+                    # This will raise errors for any tool name conflicts
+                    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+                    # Skip ModelRequestNode and go directly to CallToolsNode
+                    return CallToolsNode[DepsT, NodeRunEndT](model_response=last_message)
+                elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
+                    raise exceptions.UserError(
+                        'Cannot provide a new user prompt when the message history ends with '
+                        'a model response containing unprocessed tool calls. Either process the '
+                        'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
+                        '`ModelRequest` with `ToolResultPart`s.'
+                    )
+
+        if self.user_prompt is not None:
+            parts.append(_messages.UserPromptPart(self.user_prompt))
+
+        instructions = await ctx.deps.get_instructions(run_context)
+        next_message = _messages.ModelRequest(parts, instructions=instructions)
 
-        return messages, _messages.ModelRequest(parts, instructions=instructions)
+        return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
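
The hunk above changes which node a run starts on, based on how the incoming history ends. As a paraphrase of that dispatch, here is a hypothetical standalone sketch (the function name and string results are invented for illustration; only the `pydantic_ai.messages` types are real):

from pydantic_ai import messages as _messages


def first_node_kind(history: list[_messages.ModelMessage], user_prompt: str | None) -> str:
    """Hypothetical paraphrase of UserPromptNode.run's new dispatch logic."""
    if history and isinstance(last := history[-1], _messages.ModelResponse):
        if user_prompt is None:
            # Resume by processing the pending tool calls (or final result)
            return 'CallToolsNode'
        if any(isinstance(p, _messages.ToolCallPart) for p in last.parts):
            # A new prompt alongside unprocessed tool calls is rejected (UserError)
            raise ValueError('unprocessed tool calls in history')
    # Empty history, or history ending on a ModelRequest: ask the model next
    return 'ModelRequestNode'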
@@ -250,11 +248,6 @@ async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
-    run_context = build_run_context(ctx)
-
-    # This will raise errors for any tool name conflicts
-    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
     output_schema = ctx.deps.output_schema
     output_object = None
     if isinstance(output_schema, _output.NativeOutputSchema):
@@ -357,21 +350,21 @@ async def _prepare_request(
 
         run_context = build_run_context(ctx)
 
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
+        # This will raise errors for any tool name conflicts
+        ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
 
         model_request_parameters = await _prepare_request_parameters(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
 
-        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
-
+        model_settings = ctx.deps.model_settings
         usage = ctx.state.usage
         if ctx.deps.usage_limits.count_tokens_before_request:
             # Copy to avoid modifying the original usage object with the counted usage
             usage = dataclasses.replace(usage)
 
-            counted_usage = await ctx.deps.model.count_tokens(
-                message_history, ctx.deps.model_settings, model_request_parameters
-            )
+            counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters)
             usage.incr(counted_usage)
 
         ctx.deps.usage_limits.check_before_request(usage)
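
Taken together, the change means `message_history` may now end on a `ModelResponse`: if that response carries unprocessed tool calls and no new user prompt is given, the run starts by executing them; a new prompt alongside pending tool calls raises `UserError` instead. A minimal end-to-end sketch of that usage follows; the agent, tool, model choice, and history are invented for illustration, and only the `user_prompt=None` / `message_history` behaviour comes from this commit:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelRequest, ModelResponse, ToolCallPart, UserPromptPart

agent = Agent('openai:gpt-4o')  # hypothetical model choice


@agent.tool_plain
def get_weather(city: str) -> str:
    return f'Sunny in {city}'


# A history captured from an interrupted run, ending on a ModelResponse that
# still contains an unprocessed ToolCallPart:
history = [
    ModelRequest(parts=[UserPromptPart('What is the weather in Paris?')]),
    ModelResponse(parts=[ToolCallPart(tool_name='get_weather', args={'city': 'Paris'})]),
]

# With user_prompt=None, the run now starts at CallToolsNode: the pending tool
# call is executed and the conversation continues from its result.
result = agent.run_sync(None, message_history=history)
print(result.output)

# Passing a new user prompt while tool calls are still pending raises UserError.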
