60
60
- `'early'`: Stop processing other tool calls once a final result is found
61
61
- `'exhaustive'`: Process all tool calls even after finding a final result
62
62
"""
63
- RunResultData = TypeVar ('RunResultData ' )
63
+ RunResultDataT = TypeVar ('RunResultDataT ' )
64
64
"""Type variable for the result data of a run where `result_type` was customized on the run call."""
65
65
66
66
@@ -214,15 +214,15 @@ async def run(
214
214
self ,
215
215
user_prompt : str ,
216
216
* ,
217
- result_type : type [RunResultData ],
217
+ result_type : type [RunResultDataT ],
218
218
message_history : list [_messages .ModelMessage ] | None = None ,
219
219
model : models .Model | models .KnownModelName | None = None ,
220
220
deps : AgentDepsT = None ,
221
221
model_settings : ModelSettings | None = None ,
222
222
usage_limits : _usage .UsageLimits | None = None ,
223
223
usage : _usage .Usage | None = None ,
224
224
infer_name : bool = True ,
225
- ) -> result .RunResult [RunResultData ]: ...
225
+ ) -> result .RunResult [RunResultDataT ]: ...
226
226
227
227
async def run (
228
228
self ,
@@ -234,7 +234,7 @@ async def run(
234
234
model_settings : ModelSettings | None = None ,
235
235
usage_limits : _usage .UsageLimits | None = None ,
236
236
usage : _usage .Usage | None = None ,
237
- result_type : type [RunResultData ] | None = None ,
237
+ result_type : type [RunResultDataT ] | None = None ,
238
238
infer_name : bool = True ,
239
239
) -> result .RunResult [Any ]:
240
240
"""Run the agent with a user prompt in async mode.
@@ -352,21 +352,21 @@ def run_sync(
352
352
self ,
353
353
user_prompt : str ,
354
354
* ,
355
- result_type : type [RunResultData ] | None ,
355
+ result_type : type [RunResultDataT ] | None ,
356
356
message_history : list [_messages .ModelMessage ] | None = None ,
357
357
model : models .Model | models .KnownModelName | None = None ,
358
358
deps : AgentDepsT = None ,
359
359
model_settings : ModelSettings | None = None ,
360
360
usage_limits : _usage .UsageLimits | None = None ,
361
361
usage : _usage .Usage | None = None ,
362
362
infer_name : bool = True ,
363
- ) -> result .RunResult [RunResultData ]: ...
363
+ ) -> result .RunResult [RunResultDataT ]: ...
364
364
365
365
def run_sync (
366
366
self ,
367
367
user_prompt : str ,
368
368
* ,
369
- result_type : type [RunResultData ] | None = None ,
369
+ result_type : type [RunResultDataT ] | None = None ,
370
370
message_history : list [_messages .ModelMessage ] | None = None ,
371
371
model : models .Model | models .KnownModelName | None = None ,
372
372
deps : AgentDepsT = None ,
@@ -442,22 +442,22 @@ def run_stream(
442
442
self ,
443
443
user_prompt : str ,
444
444
* ,
445
- result_type : type [RunResultData ],
445
+ result_type : type [RunResultDataT ],
446
446
message_history : list [_messages .ModelMessage ] | None = None ,
447
447
model : models .Model | models .KnownModelName | None = None ,
448
448
deps : AgentDepsT = None ,
449
449
model_settings : ModelSettings | None = None ,
450
450
usage_limits : _usage .UsageLimits | None = None ,
451
451
usage : _usage .Usage | None = None ,
452
452
infer_name : bool = True ,
453
- ) -> AbstractAsyncContextManager [result .StreamedRunResult [AgentDepsT , RunResultData ]]: ...
453
+ ) -> AbstractAsyncContextManager [result .StreamedRunResult [AgentDepsT , RunResultDataT ]]: ...
454
454
455
455
@asynccontextmanager
456
456
async def run_stream (
457
457
self ,
458
458
user_prompt : str ,
459
459
* ,
460
- result_type : type [RunResultData ] | None = None ,
460
+ result_type : type [RunResultDataT ] | None = None ,
461
461
message_history : list [_messages .ModelMessage ] | None = None ,
462
462
model : models .Model | models .KnownModelName | None = None ,
463
463
deps : AgentDepsT = None ,
@@ -572,7 +572,7 @@ async def on_complete():
572
572
# there are result validators that might convert the result data from an overridden
573
573
# `result_type` to a type that is not valid as such.
574
574
result_validators = cast (
575
- list [_result .ResultValidator [AgentDepsT , RunResultData ]], self ._result_validators
575
+ list [_result .ResultValidator [AgentDepsT , RunResultDataT ]], self ._result_validators
576
576
)
577
577
578
578
yield result .StreamedRunResult (
@@ -999,7 +999,7 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
999
999
return model_
1000
1000
1001
1001
async def _prepare_model (
1002
- self , run_context : RunContext [AgentDepsT ], result_schema : _result .ResultSchema [RunResultData ] | None
1002
+ self , run_context : RunContext [AgentDepsT ], result_schema : _result .ResultSchema [RunResultDataT ] | None
1003
1003
) -> models .AgentModel :
1004
1004
"""Build tools and create an agent model."""
1005
1005
function_tools : list [ToolDefinition ] = []
@@ -1035,8 +1035,8 @@ async def _reevaluate_dynamic_prompts(
1035
1035
)
1036
1036
1037
1037
def _prepare_result_schema (
1038
- self , result_type : type [RunResultData ] | None
1039
- ) -> _result .ResultSchema [RunResultData ] | None :
1038
+ self , result_type : type [RunResultDataT ] | None
1039
+ ) -> _result .ResultSchema [RunResultDataT ] | None :
1040
1040
if result_type is not None :
1041
1041
if self ._result_validators :
1042
1042
raise exceptions .UserError ('Cannot set a custom run `result_type` when the agent has result validators' )
@@ -1053,7 +1053,7 @@ async def _prepare_messages(
1053
1053
run_context : RunContext [AgentDepsT ],
1054
1054
) -> list [_messages .ModelMessage ]:
1055
1055
try :
1056
- ctx_messages = _messages_ctx_var . get ()
1056
+ ctx_messages = get_captured_run_messages ()
1057
1057
except LookupError :
1058
1058
messages : list [_messages .ModelMessage ] = []
1059
1059
else :
@@ -1080,8 +1080,8 @@ async def _handle_model_response(
1080
1080
self ,
1081
1081
model_response : _messages .ModelResponse ,
1082
1082
run_context : RunContext [AgentDepsT ],
1083
- result_schema : _result .ResultSchema [RunResultData ] | None ,
1084
- ) -> tuple [_MarkFinalResult [RunResultData ] | None , list [_messages .ModelRequestPart ]]:
1083
+ result_schema : _result .ResultSchema [RunResultDataT ] | None ,
1084
+ ) -> tuple [_MarkFinalResult [RunResultDataT ] | None , list [_messages .ModelRequestPart ]]:
1085
1085
"""Process a non-streamed response from the model.
1086
1086
1087
1087
Returns:
@@ -1110,11 +1110,11 @@ async def _handle_model_response(
1110
1110
raise exceptions .UnexpectedModelBehavior ('Received empty model response' )
1111
1111
1112
1112
async def _handle_text_response (
1113
- self , text : str , run_context : RunContext [AgentDepsT ], result_schema : _result .ResultSchema [RunResultData ] | None
1114
- ) -> tuple [_MarkFinalResult [RunResultData ] | None , list [_messages .ModelRequestPart ]]:
1113
+ self , text : str , run_context : RunContext [AgentDepsT ], result_schema : _result .ResultSchema [RunResultDataT ] | None
1114
+ ) -> tuple [_MarkFinalResult [RunResultDataT ] | None , list [_messages .ModelRequestPart ]]:
1115
1115
"""Handle a plain text response from the model for non-streaming responses."""
1116
1116
if self ._allow_text_result (result_schema ):
1117
- result_data_input = cast (RunResultData , text )
1117
+ result_data_input = cast (RunResultDataT , text )
1118
1118
try :
1119
1119
result_data = await self ._validate_result (result_data_input , run_context , None )
1120
1120
except _result .ToolRetryError as e :
@@ -1133,13 +1133,13 @@ async def _handle_structured_response(
1133
1133
self ,
1134
1134
tool_calls : list [_messages .ToolCallPart ],
1135
1135
run_context : RunContext [AgentDepsT ],
1136
- result_schema : _result .ResultSchema [RunResultData ] | None ,
1137
- ) -> tuple [_MarkFinalResult [RunResultData ] | None , list [_messages .ModelRequestPart ]]:
1136
+ result_schema : _result .ResultSchema [RunResultDataT ] | None ,
1137
+ ) -> tuple [_MarkFinalResult [RunResultDataT ] | None , list [_messages .ModelRequestPart ]]:
1138
1138
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
1139
1139
assert tool_calls , 'Expected at least one tool call'
1140
1140
1141
1141
# first look for the result tool call
1142
- final_result : _MarkFinalResult [RunResultData ] | None = None
1142
+ final_result : _MarkFinalResult [RunResultDataT ] | None = None
1143
1143
1144
1144
parts : list [_messages .ModelRequestPart ] = []
1145
1145
if result_schema is not None :
@@ -1168,7 +1168,7 @@ async def _process_function_tools(
1168
1168
tool_calls : list [_messages .ToolCallPart ],
1169
1169
result_tool_name : str | None ,
1170
1170
run_context : RunContext [AgentDepsT ],
1171
- result_schema : _result .ResultSchema [RunResultData ] | None ,
1171
+ result_schema : _result .ResultSchema [RunResultDataT ] | None ,
1172
1172
) -> list [_messages .ModelRequestPart ]:
1173
1173
"""Process function (non-result) tool calls in parallel.
1174
1174
@@ -1227,7 +1227,7 @@ async def _handle_streamed_response(
1227
1227
self ,
1228
1228
streamed_response : models .StreamedResponse ,
1229
1229
run_context : RunContext [AgentDepsT ],
1230
- result_schema : _result .ResultSchema [RunResultData ] | None ,
1230
+ result_schema : _result .ResultSchema [RunResultDataT ] | None ,
1231
1231
) -> _MarkFinalResult [models .StreamedResponse ] | tuple [_messages .ModelResponse , list [_messages .ModelRequestPart ]]:
1232
1232
"""Process a streamed response from the model.
1233
1233
@@ -1282,15 +1282,15 @@ async def _handle_streamed_response(
1282
1282
1283
1283
async def _validate_result (
1284
1284
self ,
1285
- result_data : RunResultData ,
1285
+ result_data : RunResultDataT ,
1286
1286
run_context : RunContext [AgentDepsT ],
1287
1287
tool_call : _messages .ToolCallPart | None ,
1288
- ) -> RunResultData :
1288
+ ) -> RunResultDataT :
1289
1289
if self ._result_validators :
1290
1290
agent_result_data = cast (ResultDataT , result_data )
1291
1291
for validator in self ._result_validators :
1292
1292
agent_result_data = await validator .validate (agent_result_data , tool_call , run_context )
1293
- return cast (RunResultData , agent_result_data )
1293
+ return cast (RunResultDataT , agent_result_data )
1294
1294
else :
1295
1295
return result_data
1296
1296
@@ -1315,7 +1315,7 @@ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_message
1315
1315
def _unknown_tool (
1316
1316
self ,
1317
1317
tool_name : str ,
1318
- result_schema : _result .ResultSchema [RunResultData ] | None ,
1318
+ result_schema : _result .ResultSchema [RunResultDataT ] | None ,
1319
1319
) -> _messages .RetryPromptPart :
1320
1320
names = list (self ._function_tools .keys ())
1321
1321
if result_schema :
@@ -1358,7 +1358,7 @@ def _infer_name(self, function_frame: FrameType | None) -> None:
1358
1358
return
1359
1359
1360
1360
@staticmethod
1361
- def _allow_text_result (result_schema : _result .ResultSchema [RunResultData ] | None ) -> bool :
1361
+ def _allow_text_result (result_schema : _result .ResultSchema [RunResultDataT ] | None ) -> bool :
1362
1362
return result_schema is None or result_schema .allow_text_result
1363
1363
1364
1364
@property
@@ -1413,6 +1413,10 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1413
1413
_messages_ctx_var .reset (token )
1414
1414
1415
1415
1416
+ def get_captured_run_messages () -> _RunMessages :
1417
+ return _messages_ctx_var .get ()
1418
+
1419
+
1416
1420
@dataclasses .dataclass
1417
1421
class _MarkFinalResult (Generic [ResultDataT ]):
1418
1422
"""Marker class to indicate that the result is the final result.
0 commit comments