6
6
from collections .abc import AsyncIterator , Awaitable , Iterator , Sequence
7
7
from contextlib import asynccontextmanager , contextmanager
8
8
from contextvars import ContextVar
9
- from dataclasses import dataclass , field
10
9
from types import FrameType
11
10
from typing import Any , Callable , Generic , Literal , cast , final , overload
12
11
60
59
61
60
62
61
@final
63
- @dataclass (init = False )
62
+ @dataclasses . dataclass (init = False )
64
63
class Agent (Generic [AgentDeps , ResultData ]):
65
64
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
66
65
@@ -100,17 +99,17 @@ class Agent(Generic[AgentDeps, ResultData]):
100
99
be merged with this value, with the runtime argument taking priority.
101
100
"""
102
101
103
- _result_schema : _result .ResultSchema [ResultData ] | None = field (repr = False )
104
- _result_validators : list [_result .ResultValidator [AgentDeps , ResultData ]] = field (repr = False )
105
- _allow_text_result : bool = field (repr = False )
106
- _system_prompts : tuple [str , ...] = field (repr = False )
107
- _function_tools : dict [str , Tool [AgentDeps ]] = field (repr = False )
108
- _default_retries : int = field (repr = False )
109
- _system_prompt_functions : list [_system_prompt .SystemPromptRunner [AgentDeps ]] = field (repr = False )
110
- _deps_type : type [AgentDeps ] = field (repr = False )
111
- _max_result_retries : int = field (repr = False )
112
- _override_deps : _utils .Option [AgentDeps ] = field (default = None , repr = False )
113
- _override_model : _utils .Option [models .Model ] = field (default = None , repr = False )
102
+ _result_schema : _result .ResultSchema [ResultData ] | None = dataclasses . field (repr = False )
103
+ _result_validators : list [_result .ResultValidator [AgentDeps , ResultData ]] = dataclasses . field (repr = False )
104
+ _allow_text_result : bool = dataclasses . field (repr = False )
105
+ _system_prompts : tuple [str , ...] = dataclasses . field (repr = False )
106
+ _function_tools : dict [str , Tool [AgentDeps ]] = dataclasses . field (repr = False )
107
+ _default_retries : int = dataclasses . field (repr = False )
108
+ _system_prompt_functions : list [_system_prompt .SystemPromptRunner [AgentDeps ]] = dataclasses . field (repr = False )
109
+ _deps_type : type [AgentDeps ] = dataclasses . field (repr = False )
110
+ _max_result_retries : int = dataclasses . field (repr = False )
111
+ _override_deps : _utils .Option [AgentDeps ] = dataclasses . field (default = None , repr = False )
112
+ _override_model : _utils .Option [models .Model ] = dataclasses . field (default = None , repr = False )
114
113
115
114
def __init__ (
116
115
self ,
@@ -836,15 +835,15 @@ async def _prepare_messages(
836
835
self , user_prompt : str , message_history : list [_messages .ModelMessage ] | None , run_context : RunContext [AgentDeps ]
837
836
) -> list [_messages .ModelMessage ]:
838
837
try :
839
- messages = _messages_ctx_var .get ()
838
+ ctx_messages = _messages_ctx_var .get ()
840
839
except LookupError :
841
- messages = []
840
+ messages : list [ _messages . ModelMessage ] = []
842
841
else :
843
- if messages :
844
- raise exceptions . UserError (
845
- 'The capture_run_messages() context manager may only be used to wrap '
846
- 'one call to run(), run_sync(), or run_stream().'
847
- )
842
+ if ctx_messages . used :
843
+ messages = []
844
+ else :
845
+ messages = ctx_messages . messages
846
+ ctx_messages . used = True
848
847
849
848
if message_history :
850
849
# shallow copy messages
@@ -1138,7 +1137,13 @@ def last_run_messages(self) -> list[_messages.ModelMessage]:
1138
1137
raise AttributeError ('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.' )
1139
1138
1140
1139
1141
- _messages_ctx_var : ContextVar [list [_messages .ModelMessage ]] = ContextVar ('var' )
1140
+ @dataclasses .dataclass
1141
+ class _RunMessages :
1142
+ messages : list [_messages .ModelMessage ]
1143
+ used : bool = False
1144
+
1145
+
1146
+ _messages_ctx_var : ContextVar [_RunMessages ] = ContextVar ('var' )
1142
1147
1143
1148
1144
1149
@contextmanager
@@ -1162,21 +1167,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1162
1167
```
1163
1168
1164
1169
!!! note
1165
- You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
1166
- If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised .
1170
+ If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
1171
+ `messages` will represent the messages exchanged during the first call only .
1167
1172
"""
1168
1173
try :
1169
- yield _messages_ctx_var .get ()
1174
+ yield _messages_ctx_var .get (). messages
1170
1175
except LookupError :
1171
1176
messages : list [_messages .ModelMessage ] = []
1172
- token = _messages_ctx_var .set (messages )
1177
+ token = _messages_ctx_var .set (_RunMessages ( messages ) )
1173
1178
try :
1174
1179
yield messages
1175
1180
finally :
1176
1181
_messages_ctx_var .reset (token )
1177
1182
1178
1183
1179
- @dataclass
1184
+ @dataclasses . dataclass
1180
1185
class _MarkFinalResult (Generic [ResultData ]):
1181
1186
"""Marker class to indicate that the result is the final result.
1182
1187
0 commit comments