   5 |    5 | import inspect
   6 |    6 | from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
   7 |    7 | from contextlib import asynccontextmanager, contextmanager
     |    8 | +from contextvars import ContextVar
   8 |    9 | from dataclasses import dataclass, field
   9 |   10 | from types import FrameType
  10 |   11 | from typing import Any, Callable, Generic, Literal, cast, final, overload
  11 |   12 |
  12 |   13 | import logfire_api
  13 |      | -from typing_extensions import assert_never
     |   14 | +from typing_extensions import assert_never, deprecated
  14 |   15 |
  15 |   16 | from . import (
  16 |   17 |     _result,
  35 |   36 |     ToolPrepareFunc,
  36 |   37 | )
  37 |   38 |
  38 |      | -__all__ = ('Agent',)
     |   39 | +__all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
  39 |   40 |
  40 |   41 | _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
  41 |   42 |
@@ -89,12 +90,6 @@ class Agent(Generic[AgentDeps, ResultData]):
  89 |   90 |     be merged with this value, with the runtime argument taking priority.
  90 |   91 |     """
  91 |   92 |
  92 |      | -    last_run_messages: list[_messages.ModelMessage] | None
  93 |      | -    """The messages from the last run, useful when a run raised an exception.
  94 |      | -
  95 |      | -    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
  96 |      | -    """
  97 |      | -
  98 |   93 |     _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
  99 |   94 |     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
 100 |   95 |     _allow_text_result: bool = field(repr=False)
@@ -161,7 +156,6 @@ def __init__(
 161 |  156 |         self.end_strategy = end_strategy
 162 |  157 |         self.name = name
 163 |  158 |         self.model_settings = model_settings
 164 |      | -        self.last_run_messages = None
 165 |  159 |         self._result_schema = _result.ResultSchema[result_type].build(
 166 |  160 |             result_type, result_tool_name, result_tool_description
 167 |  161 |         )
@@ -234,7 +228,7 @@ async def run(
 234 |  228 |         ) as run_span:
 235 |  229 |             run_context = RunContext(deps, 0, [], None, model_used)
 236 |  230 |             messages = await self._prepare_messages(user_prompt, message_history, run_context)
 237 |      | -            self.last_run_messages = run_context.messages = messages
     |  231 | +            run_context.messages = messages
 238 |  232 |
 239 |  233 |             for tool in self._function_tools.values():
 240 |  234 |                 tool.current_retry = 0
@@ -393,7 +387,7 @@ async def main():
 393 |  387 |         ) as run_span:
 394 |  388 |             run_context = RunContext(deps, 0, [], None, model_used)
 395 |  389 |             messages = await self._prepare_messages(user_prompt, message_history, run_context)
 396 |      | -            self.last_run_messages = run_context.messages = messages
     |  390 | +            run_context.messages = messages
 397 |  391 |
 398 |  392 |             for tool in self._function_tools.values():
 399 |  393 |                 tool.current_retry = 0
@@ -614,7 +608,7 @@ async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
 614 |  608 |         #> success (no tool calls)
 615 |  609 |         ```
 616 |  610 |         """
 617 |      | -        self._result_validators.append(_result.ResultValidator(func))
     |  611 | +        self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
 618 |  612 |         return func
 619 |  613 |
 620 |  614 |     @overload
@@ -835,14 +829,25 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
 835 |  829 |     async def _prepare_messages(
 836 |  830 |         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
 837 |  831 |     ) -> list[_messages.ModelMessage]:
     |  832 | +        try:
     |  833 | +            messages = _messages_ctx_var.get()
     |  834 | +        except LookupError:
     |  835 | +            messages = []
     |  836 | +        else:
     |  837 | +            if messages:
     |  838 | +                raise exceptions.UserError(
     |  839 | +                    'The capture_run_messages() context manager may only be used to wrap '
     |  840 | +                    'one call to run(), run_sync(), or run_stream().'
     |  841 | +                )
     |  842 | +
 838 |  843 |         if message_history:
 839 |  844 |             # shallow copy messages
 840 |      | -            messages = message_history.copy()
     |  845 | +            messages.extend(message_history)
 841 |  846 |             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
 842 |  847 |         else:
 843 |  848 |             parts = await self._sys_parts(run_context)
 844 |  849 |             parts.append(_messages.UserPromptPart(user_prompt))
 845 |      | -            messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
     |  850 | +            messages.append(_messages.ModelRequest(parts))
 846 |  851 |
 847 |  852 |         return messages
 848 |  853 |
@@ -1119,6 +1124,51 @@ def _infer_name(self, function_frame: FrameType | None) -> None:
1119 | 1124 |                             self.name = name
1120 | 1125 |                             return
1121 | 1126 |
     | 1127 | +    @property
     | 1128 | +    @deprecated(
     | 1129 | +        'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
     | 1130 | +    )
     | 1131 | +    def last_run_messages(self) -> list[_messages.ModelMessage]:
     | 1132 | +        raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
     | 1133 | +
     | 1134 | +
     | 1135 | +_messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
     | 1136 | +
     | 1137 | +
     | 1138 | +@contextmanager
     | 1139 | +def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
     | 1140 | +    """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
     | 1141 | +
     | 1142 | +    Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
     | 1143 | +
     | 1144 | +    Examples:
     | 1145 | +    ```python
     | 1146 | +    from pydantic_ai import Agent, capture_run_messages
     | 1147 | +
     | 1148 | +    agent = Agent('test')
     | 1149 | +
     | 1150 | +    with capture_run_messages() as messages:
     | 1151 | +        try:
     | 1152 | +            result = agent.run_sync('foobar')
     | 1153 | +        except Exception:
     | 1154 | +            print(messages)
     | 1155 | +            raise
     | 1156 | +    ```
     | 1157 | +
     | 1158 | +    !!! note
     | 1159 | +        You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
     | 1160 | +        If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
     | 1161 | +    """
     | 1162 | +    try:
     | 1163 | +        yield _messages_ctx_var.get()
     | 1164 | +    except LookupError:
     | 1165 | +        messages: list[_messages.ModelMessage] = []
     | 1166 | +        token = _messages_ctx_var.set(messages)
     | 1167 | +        try:
     | 1168 | +            yield messages
     | 1169 | +        finally:
     | 1170 | +            _messages_ctx_var.reset(token)
     | 1171 | +
1122 | 1172 |
1123 | 1173 | @dataclass
1124 | 1174 | class _MarkFinalResult(Generic[ResultData]):
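Not part of the patch, but for anyone trying this branch locally, here is a minimal sketch of the behaviour the diff introduces. It assumes the `'test'` model used in the docstring example above and the exact error strings added in this diff; nothing else is taken from outside the change.

```python
from pydantic_ai import Agent, capture_run_messages
from pydantic_ai.exceptions import UserError

agent = Agent('test')  # the built-in test model, as in the docstring example

# A single run inside the context: `messages` is the list that
# _prepare_messages() picks up from the ContextVar and fills.
with capture_run_messages() as messages:
    agent.run_sync('hello')
print(messages)  # the request/response messages captured for that run

# A second run inside the same context hits the new guard in _prepare_messages().
with capture_run_messages() as messages:
    agent.run_sync('first question')
    try:
        agent.run_sync('second question')
    except UserError as exc:
        print(exc)
        #> The capture_run_messages() context manager may only be used to wrap one call to run(), run_sync(), or run_stream().

# The old attribute is now a deprecated property that always raises.
try:
    agent.last_run_messages
except AttributeError as exc:
    print(exc)
    #> The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.
```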