Skip to content

Commit d94931e

Browse files
authored
Make capture_run_messages support nested agent calls (#573)
1 parent fde6c9a commit d94931e

File tree

3 files changed

+67
-35
lines changed

3 files changed

+67
-35
lines changed

docs/agents.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,4 @@ with capture_run_messages() as messages: # (2)!
494494
_(This example is complete, it can be run "as is")_
495495

496496
!!! note
497-
You may not call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context.
498-
499-
If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
497+
If you call [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only.

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
77
from contextlib import asynccontextmanager, contextmanager
88
from contextvars import ContextVar
9-
from dataclasses import dataclass, field
109
from types import FrameType
1110
from typing import Any, Callable, Generic, Literal, cast, final, overload
1211

@@ -60,7 +59,7 @@
6059

6160

6261
@final
63-
@dataclass(init=False)
62+
@dataclasses.dataclass(init=False)
6463
class Agent(Generic[AgentDeps, ResultData]):
6564
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
6665
@@ -100,17 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
10099
be merged with this value, with the runtime argument taking priority.
101100
"""
102101

103-
_result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
104-
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
105-
_allow_text_result: bool = field(repr=False)
106-
_system_prompts: tuple[str, ...] = field(repr=False)
107-
_function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
108-
_default_retries: int = field(repr=False)
109-
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
110-
_deps_type: type[AgentDeps] = field(repr=False)
111-
_max_result_retries: int = field(repr=False)
112-
_override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
113-
_override_model: _utils.Option[models.Model] = field(default=None, repr=False)
102+
_result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
103+
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
104+
_allow_text_result: bool = dataclasses.field(repr=False)
105+
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
106+
_function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
107+
_default_retries: int = dataclasses.field(repr=False)
108+
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
109+
_deps_type: type[AgentDeps] = dataclasses.field(repr=False)
110+
_max_result_retries: int = dataclasses.field(repr=False)
111+
_override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
112+
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
114113

115114
def __init__(
116115
self,
@@ -836,15 +835,15 @@ async def _prepare_messages(
836835
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
837836
) -> list[_messages.ModelMessage]:
838837
try:
839-
messages = _messages_ctx_var.get()
838+
ctx_messages = _messages_ctx_var.get()
840839
except LookupError:
841-
messages = []
840+
messages: list[_messages.ModelMessage] = []
842841
else:
843-
if messages:
844-
raise exceptions.UserError(
845-
'The capture_run_messages() context manager may only be used to wrap '
846-
'one call to run(), run_sync(), or run_stream().'
847-
)
842+
if ctx_messages.used:
843+
messages = []
844+
else:
845+
messages = ctx_messages.messages
846+
ctx_messages.used = True
848847

849848
if message_history:
850849
# shallow copy messages
@@ -1138,7 +1137,13 @@ def last_run_messages(self) -> list[_messages.ModelMessage]:
11381137
raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
11391138

11401139

1141-
_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
1140+
@dataclasses.dataclass
1141+
class _RunMessages:
1142+
messages: list[_messages.ModelMessage]
1143+
used: bool = False
1144+
1145+
1146+
_messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
11421147

11431148

11441149
@contextmanager
@@ -1162,21 +1167,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
11621167
```
11631168
11641169
!!! note
1165-
You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
1166-
If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
1170+
If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
1171+
`messages` will represent the messages exchanged during the first call only.
11671172
"""
11681173
try:
1169-
yield _messages_ctx_var.get()
1174+
yield _messages_ctx_var.get().messages
11701175
except LookupError:
11711176
messages: list[_messages.ModelMessage] = []
1172-
token = _messages_ctx_var.set(messages)
1177+
token = _messages_ctx_var.set(_RunMessages(messages))
11731178
try:
11741179
yield messages
11751180
finally:
11761181
_messages_ctx_var.reset(token)
11771182

11781183

1179-
@dataclass
1184+
@dataclasses.dataclass
11801185
class _MarkFinalResult(Generic[ResultData]):
11811186
"""Marker class to indicate that the result is the final result.
11821187

tests/test_agent.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,15 +1218,44 @@ def test_double_capture_run_messages(set_event_loop: None) -> None:
12181218
assert messages == []
12191219
result = agent.run_sync('Hello')
12201220
assert result.data == 'success (no tool calls)'
1221-
with pytest.raises(UserError) as exc_info:
1222-
agent.run_sync('Hello')
1223-
assert (
1224-
str(exc_info.value)
1225-
== 'The capture_run_messages() context manager may only be used to wrap one call to run(), run_sync(), or run_stream().'
1226-
)
1221+
result2 = agent.run_sync('Hello 2')
1222+
assert result2.data == 'success (no tool calls)'
1223+
12271224
assert messages == snapshot(
12281225
[
12291226
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
12301227
ModelResponse(parts=[TextPart(content='success (no tool calls)')], timestamp=IsNow(tz=timezone.utc)),
12311228
]
12321229
)
1230+
1231+
1232+
def test_capture_run_messages_tool_agent(set_event_loop: None) -> None:
1233+
agent_outer = Agent('test')
1234+
agent_inner = Agent(TestModel(custom_result_text='inner agent result'))
1235+
1236+
@agent_outer.tool_plain
1237+
async def foobar(x: str) -> str:
1238+
result_ = await agent_inner.run(x)
1239+
return result_.data
1240+
1241+
with capture_run_messages() as messages:
1242+
result = agent_outer.run_sync('foobar')
1243+
1244+
assert result.data == snapshot('{"foobar":"inner agent result"}')
1245+
assert messages == snapshot(
1246+
[
1247+
ModelRequest(parts=[UserPromptPart(content='foobar', timestamp=IsNow(tz=timezone.utc))]),
1248+
ModelResponse(
1249+
parts=[ToolCallPart(tool_name='foobar', args=ArgsDict(args_dict={'x': 'a'}))],
1250+
timestamp=IsNow(tz=timezone.utc),
1251+
),
1252+
ModelRequest(
1253+
parts=[
1254+
ToolReturnPart(tool_name='foobar', content='inner agent result', timestamp=IsNow(tz=timezone.utc))
1255+
]
1256+
),
1257+
ModelResponse(
1258+
parts=[TextPart(content='{"foobar":"inner agent result"}')], timestamp=IsNow(tz=timezone.utc)
1259+
),
1260+
]
1261+
)

0 commit comments

Comments
 (0)