@@ -23,7 +23,7 @@
 from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
 from .exceptions import ToolRetryError
 from .output import OutputDataT, OutputSpec
-from .settings import ModelSettings, merge_model_settings
+from .settings import ModelSettings
 from .tools import RunContext, ToolDefinition, ToolKind
 
 if TYPE_CHECKING:
@@ -158,28 +158,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
 
     async def run(
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> ModelRequestNode[DepsT, NodeRunEndT]:
-        return ModelRequestNode[DepsT, NodeRunEndT](request=await self._get_first_message(ctx))
-
-    async def _get_first_message(
-        self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
-    ) -> _messages.ModelRequest:
-        run_context = build_run_context(ctx)
-        history, next_message = await self._prepare_messages(
-            self.user_prompt, ctx.state.message_history, ctx.deps.get_instructions, run_context
-        )
-        ctx.state.message_history = history
-        run_context.messages = history
-
-        return next_message
-
-    async def _prepare_messages(
-        self,
-        user_prompt: str | Sequence[_messages.UserContent] | None,
-        message_history: list[_messages.ModelMessage] | None,
-        get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]],
-        run_context: RunContext[DepsT],
-    ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
+    ) -> Union[ModelRequestNode[DepsT, NodeRunEndT], CallToolsNode[DepsT, NodeRunEndT]]:  # noqa: UP007
         try:
             ctx_messages = get_captured_run_messages()
         except LookupError:
@@ -191,29 +170,48 @@ async def _prepare_messages(
                 messages = ctx_messages.messages
                 ctx_messages.used = True
 
+        # Add message history to the `capture_run_messages` list, which will be empty at this point
+        messages.extend(ctx.state.message_history)
+        # Use the `capture_run_messages` list as the message history so that new messages are added to it
+        ctx.state.message_history = messages
+
+        run_context = build_run_context(ctx)
+
         parts: list[_messages.ModelRequestPart] = []
-        instructions = await get_instructions(run_context)
-        if message_history:
-            # Shallow copy messages
-            messages.extend(message_history)
+        if messages:
             # Reevaluate any dynamic system prompt parts
             await self._reevaluate_dynamic_prompts(messages, run_context)
         else:
             parts.extend(await self._sys_parts(run_context))
 
-        if user_prompt is not None:
-            parts.append(_messages.UserPromptPart(user_prompt))
-        elif (
-            len(parts) == 0
-            and message_history
-            and (last_message := message_history[-1])
-            and isinstance(last_message, _messages.ModelRequest)
-        ):
-            # Drop last message that came from history and reuse its parts
-            messages.pop()
-            parts.extend(last_message.parts)
+        if messages and (last_message := messages[-1]):
+            if isinstance(last_message, _messages.ModelRequest) and self.user_prompt is None:
+                # Drop last message from history and reuse its parts
+                messages.pop()
+                parts.extend(last_message.parts)
+            elif isinstance(last_message, _messages.ModelResponse):
+                if self.user_prompt is None:
+                    # `CallToolsNode` requires the tool manager to be prepared for the run step
+                    # This will raise errors for any tool name conflicts
+                    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+                    # Skip ModelRequestNode and go directly to CallToolsNode
+                    return CallToolsNode[DepsT, NodeRunEndT](model_response=last_message)
+                elif any(isinstance(part, _messages.ToolCallPart) for part in last_message.parts):
+                    raise exceptions.UserError(
+                        'Cannot provide a new user prompt when the message history ends with '
+                        'a model response containing unprocessed tool calls. Either process the '
+                        'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
+                        '`ModelRequest` with `ToolResultPart`s.'
+                    )
+
+        if self.user_prompt is not None:
+            parts.append(_messages.UserPromptPart(self.user_prompt))
+
+        instructions = await ctx.deps.get_instructions(run_context)
+        next_message = _messages.ModelRequest(parts, instructions=instructions)
 
-        return messages, _messages.ModelRequest(parts, instructions=instructions)
+        return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
 
     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
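
Note on the hunk above: a run whose message history already ends in a `ModelResponse` can now be resumed without sending a new model request. Below is a minimal usage sketch (not part of the diff), assuming the public `Agent.iter` API; the model name and the `history` argument (a `list[ModelMessage]` captured from an earlier run that ended with unprocessed tool calls) are illustrative:

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage

agent = Agent('openai:gpt-4o')  # illustrative model name


async def resume_run(history: list[ModelMessage]):
    # With user_prompt=None and a history ending in a ModelResponse,
    # UserPromptNode returns a CallToolsNode directly, so the pending
    # tool calls are processed instead of a fresh ModelRequest being sent.
    async with agent.iter(None, message_history=history) as run:
        async for _node in run:
            pass
    # Passing a non-None user prompt here instead would raise the
    # UserError added above if the last response has unprocessed tool calls.
    return run.result
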
@@ -250,11 +248,6 @@ async def _prepare_request_parameters(
     ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
-    run_context = build_run_context(ctx)
-
-    # This will raise errors for any tool name conflicts
-    ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
-
     output_schema = ctx.deps.output_schema
     output_object = None
     if isinstance(output_schema, _output.NativeOutputSchema):
@@ -357,21 +350,21 @@ async def _prepare_request(
 
         run_context = build_run_context(ctx)
 
-        model_settings = merge_model_settings(ctx.deps.model_settings, None)
+        # This will raise errors for any tool name conflicts
+        ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
+
+        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
 
         model_request_parameters = await _prepare_request_parameters(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
 
-        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context)
-
+        model_settings = ctx.deps.model_settings
         usage = ctx.state.usage
         if ctx.deps.usage_limits.count_tokens_before_request:
             # Copy to avoid modifying the original usage object with the counted usage
             usage = dataclasses.replace(usage)
 
-            counted_usage = await ctx.deps.model.count_tokens(
-                message_history, ctx.deps.model_settings, model_request_parameters
-            )
+            counted_usage = await ctx.deps.model.count_tokens(message_history, model_settings, model_request_parameters)
             usage.incr(counted_usage)
 
         ctx.deps.usage_limits.check_before_request(usage)
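
The `_prepare_request` hunk also routes pre-request token counting through the same `model_settings` value used for the request itself. A sketch of how that code path is exercised from user code — not taken from this PR; it assumes the `UsageLimits.count_tokens_before_request` flag referenced in the diff, with an illustrative model name and limit:

from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('openai:gpt-4o')  # illustrative model name

# With count_tokens_before_request=True, _prepare_request calls
# model.count_tokens(message_history, model_settings, ...) and checks
# the limits before the request is actually sent.
result = agent.run_sync(
    'Summarize the design in one sentence.',
    usage_limits=UsageLimits(total_tokens_limit=2000, count_tokens_before_request=True),
)
print(result.output)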