Skip to content

Commit e454bcd

Browse files
use RunContext more widely (#500)
Co-authored-by: sydney-runkle <[email protected]>
1 parent 3b7dc13 commit e454bcd

File tree

5 files changed

+79
-83
lines changed

5 files changed

+79
-83
lines changed

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,22 @@ def __post_init__(self):
2929
async def validate(
3030
self,
3131
result: ResultData,
32-
deps: AgentDeps,
33-
retry: int,
3432
tool_call: _messages.ToolCallPart | None,
35-
messages: list[_messages.ModelMessage],
33+
run_context: RunContext[AgentDeps],
3634
) -> ResultData:
3735
"""Validate a result by calling the function.
3836
3937
Args:
4038
result: The result data after Pydantic validation the message content.
41-
deps: The agent dependencies.
42-
retry: The current retry number.
4339
tool_call: The original tool call message, `None` if there was no tool call.
44-
messages: The messages exchanged so far in the conversation.
40+
run_context: The current run context.
4541
4642
Returns:
4743
Result of either the validated result data (ok) or a retry message (Err).
4844
"""
4945
if self._takes_ctx:
50-
args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
46+
ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
47+
args = ctx, result
5148
else:
5249
args = (result,)
5350

pydantic_ai_slim/pydantic_ai/_system_prompt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def __post_init__(self):
1919
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
2020
self._is_async = inspect.iscoroutinefunction(self.function)
2121

22-
async def run(self, deps: AgentDeps) -> str:
22+
async def run(self, run_context: RunContext[AgentDeps]) -> str:
2323
if self._takes_ctx:
24-
args = (RunContext(deps, 0, [], None),)
24+
args = (run_context,)
2525
else:
2626
args = ()
2727

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 45 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ class Agent(Generic[AgentDeps, ResultData]):
104104
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
105105
_deps_type: type[AgentDeps] = field(repr=False)
106106
_max_result_retries: int = field(repr=False)
107-
_current_result_retry: int = field(repr=False)
108107
_override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
109108
_override_model: _utils.Option[models.Model] = field(default=None, repr=False)
110109

@@ -180,7 +179,6 @@ def __init__(
180179
self._deps_type = deps_type
181180
self._system_prompt_functions = []
182181
self._max_result_retries = result_retries if result_retries is not None else retries
183-
self._current_result_retry = 0
184182
self._result_validators = []
185183

186184
async def run(
@@ -234,7 +232,9 @@ async def run(
234232
model_name=model_used.name(),
235233
agent_name=self.name or 'agent',
236234
) as run_span:
237-
self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
235+
run_context = RunContext(deps, 0, [], None, model_used)
236+
messages = await self._prepare_messages(user_prompt, message_history, run_context)
237+
self.last_run_messages = run_context.messages = messages
238238

239239
for tool in self._function_tools.values():
240240
tool.current_retry = 0
@@ -249,7 +249,7 @@ async def run(
249249

250250
run_step += 1
251251
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
252-
agent_model = await self._prepare_model(model_used, deps, messages)
252+
agent_model = await self._prepare_model(run_context)
253253

254254
with _logfire.span('model request', run_step=run_step) as model_req_span:
255255
model_response, request_usage = await agent_model.request(messages, model_settings)
@@ -262,7 +262,7 @@ async def run(
262262
usage_limits.check_tokens(request_usage)
263263

264264
with _logfire.span('handle model response', run_step=run_step) as handle_span:
265-
final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
265+
final_result, tool_responses = await self._handle_model_response(model_response, run_context)
266266

267267
if tool_responses:
268268
# Add parts to the conversation as a new message
@@ -391,7 +391,9 @@ async def main():
391391
model_name=model_used.name(),
392392
agent_name=self.name or 'agent',
393393
) as run_span:
394-
self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
394+
run_context = RunContext(deps, 0, [], None, model_used)
395+
messages = await self._prepare_messages(user_prompt, message_history, run_context)
396+
self.last_run_messages = run_context.messages = messages
395397

396398
for tool in self._function_tools.values():
397399
tool.current_retry = 0
@@ -406,7 +408,7 @@ async def main():
406408
usage_limits.check_before_request(usage)
407409

408410
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
409-
agent_model = await self._prepare_model(model_used, deps, messages)
411+
agent_model = await self._prepare_model(run_context)
410412

411413
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
412414
async with agent_model.request_stream(messages, model_settings) as model_response:
@@ -417,9 +419,7 @@ async def main():
417419
model_req_span.__exit__(None, None, None)
418420

419421
with _logfire.span('handle model response') as handle_span:
420-
maybe_final_result = await self._handle_streamed_model_response(
421-
model_response, deps, messages
422-
)
422+
maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
423423

424424
# Check if we got a final result
425425
if isinstance(maybe_final_result, _MarkFinalResult):
@@ -439,7 +439,7 @@ async def on_complete():
439439
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
440440
]
441441
parts = await self._process_function_tools(
442-
tool_calls, result_tool_name, deps, messages
442+
tool_calls, result_tool_name, run_context
443443
)
444444
if parts:
445445
messages.append(_messages.ModelRequest(parts))
@@ -452,7 +452,7 @@ async def on_complete():
452452
usage_limits,
453453
result_stream,
454454
self._result_schema,
455-
deps,
455+
run_context,
456456
self._result_validators,
457457
result_tool_name,
458458
on_complete,
@@ -815,41 +815,39 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
815815

816816
return model_, mode_selection
817817

818-
async def _prepare_model(
819-
self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
820-
) -> models.AgentModel:
821-
"""Create building tools and create an agent model."""
818+
async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
819+
"""Build tools and create an agent model."""
822820
function_tools: list[ToolDefinition] = []
823821

824822
async def add_tool(tool: Tool[AgentDeps]) -> None:
825-
ctx = RunContext(deps, tool.current_retry, messages, tool.name)
823+
ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
826824
if tool_def := await tool.prepare_tool_def(ctx):
827825
function_tools.append(tool_def)
828826

829827
await asyncio.gather(*map(add_tool, self._function_tools.values()))
830828

831-
return await model.agent_model(
829+
return await run_context.model.agent_model(
832830
function_tools=function_tools,
833831
allow_text_result=self._allow_text_result,
834832
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
835833
)
836834

837835
async def _prepare_messages(
838-
self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
836+
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
839837
) -> list[_messages.ModelMessage]:
840838
if message_history:
841839
# shallow copy messages
842840
messages = message_history.copy()
843841
messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
844842
else:
845-
parts = await self._sys_parts(deps)
843+
parts = await self._sys_parts(run_context)
846844
parts.append(_messages.UserPromptPart(user_prompt))
847845
messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
848846

849847
return messages
850848

851849
async def _handle_model_response(
852-
self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
850+
self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
853851
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
854852
"""Process a non-streamed response from the model.
855853
@@ -868,34 +866,34 @@ async def _handle_model_response(
868866

869867
if texts:
870868
text = '\n\n'.join(texts)
871-
return await self._handle_text_response(text, deps, conv_messages)
869+
return await self._handle_text_response(text, run_context)
872870
elif tool_calls:
873-
return await self._handle_structured_response(tool_calls, deps, conv_messages)
871+
return await self._handle_structured_response(tool_calls, run_context)
874872
else:
875873
raise exceptions.UnexpectedModelBehavior('Received empty model response')
876874

877875
async def _handle_text_response(
878-
self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
876+
self, text: str, run_context: RunContext[AgentDeps]
879877
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
880878
"""Handle a plain text response from the model for non-streaming responses."""
881879
if self._allow_text_result:
882880
result_data_input = cast(ResultData, text)
883881
try:
884-
result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
882+
result_data = await self._validate_result(result_data_input, run_context, None)
885883
except _result.ToolRetryError as e:
886-
self._incr_result_retry()
884+
self._incr_result_retry(run_context)
887885
return None, [e.tool_retry]
888886
else:
889887
return _MarkFinalResult(result_data, None), []
890888
else:
891-
self._incr_result_retry()
889+
self._incr_result_retry(run_context)
892890
response = _messages.RetryPromptPart(
893891
content='Plain text responses are not permitted, please call one of the functions instead.',
894892
)
895893
return None, [response]
896894

897895
async def _handle_structured_response(
898-
self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
896+
self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
899897
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
900898
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
901899
assert tool_calls, 'Expected at least one tool call'
@@ -909,26 +907,23 @@ async def _handle_structured_response(
909907
call, result_tool = match
910908
try:
911909
result_data = result_tool.validate(call)
912-
result_data = await self._validate_result(result_data, deps, call, conv_messages)
910+
result_data = await self._validate_result(result_data, run_context, call)
913911
except _result.ToolRetryError as e:
914-
self._incr_result_retry()
912+
self._incr_result_retry(run_context)
915913
parts.append(e.tool_retry)
916914
else:
917915
final_result = _MarkFinalResult(result_data, call.tool_name)
918916

919917
# Then build the other request parts based on end strategy
920-
parts += await self._process_function_tools(
921-
tool_calls, final_result and final_result.tool_name, deps, conv_messages
922-
)
918+
parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
923919

924920
return final_result, parts
925921

926922
async def _process_function_tools(
927923
self,
928924
tool_calls: list[_messages.ToolCallPart],
929925
result_tool_name: str | None,
930-
deps: AgentDeps,
931-
conv_messages: list[_messages.ModelMessage],
926+
run_context: RunContext[AgentDeps],
932927
) -> list[_messages.ModelRequestPart]:
933928
"""Process function (non-result) tool calls in parallel.
934929
@@ -961,7 +956,7 @@ async def _process_function_tools(
961956
)
962957
)
963958
else:
964-
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
959+
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
965960
elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
966961
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
967962
# validation, we don't add another part here
@@ -974,7 +969,7 @@ async def _process_function_tools(
974969
)
975970
)
976971
else:
977-
parts.append(self._unknown_tool(call.tool_name))
972+
parts.append(self._unknown_tool(call.tool_name, run_context))
978973

979974
# Run all tool tasks in parallel
980975
if tasks:
@@ -986,8 +981,7 @@ async def _process_function_tools(
986981
async def _handle_streamed_model_response(
987982
self,
988983
model_response: models.EitherStreamedResponse,
989-
deps: AgentDeps,
990-
conv_messages: list[_messages.ModelMessage],
984+
run_context: RunContext[AgentDeps],
991985
) -> (
992986
_MarkFinalResult[models.EitherStreamedResponse]
993987
| tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -1003,7 +997,7 @@ async def _handle_streamed_model_response(
1003997
if self._allow_text_result:
1004998
return _MarkFinalResult(model_response, None)
1005999
else:
1006-
self._incr_result_retry()
1000+
self._incr_result_retry(run_context)
10071001
response = _messages.RetryPromptPart(
10081002
content='Plain text responses are not permitted, please call one of the functions instead.',
10091003
)
@@ -1043,9 +1037,9 @@ async def _handle_streamed_model_response(
10431037
if isinstance(item, _messages.ToolCallPart):
10441038
call = item
10451039
if tool := self._function_tools.get(call.tool_name):
1046-
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
1040+
tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
10471041
else:
1048-
parts.append(self._unknown_tool(call.tool_name))
1042+
parts.append(self._unknown_tool(call.tool_name, run_context))
10491043

10501044
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
10511045
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1057,33 +1051,30 @@ async def _handle_streamed_model_response(
10571051
async def _validate_result(
10581052
self,
10591053
result_data: ResultData,
1060-
deps: AgentDeps,
1054+
run_context: RunContext[AgentDeps],
10611055
tool_call: _messages.ToolCallPart | None,
1062-
conv_messages: list[_messages.ModelMessage],
10631056
) -> ResultData:
10641057
for validator in self._result_validators:
1065-
result_data = await validator.validate(
1066-
result_data, deps, self._current_result_retry, tool_call, conv_messages
1067-
)
1058+
result_data = await validator.validate(result_data, tool_call, run_context)
10681059
return result_data
10691060

1070-
def _incr_result_retry(self) -> None:
1071-
self._current_result_retry += 1
1072-
if self._current_result_retry > self._max_result_retries:
1061+
def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
1062+
run_context.retry += 1
1063+
if run_context.retry > self._max_result_retries:
10731064
raise exceptions.UnexpectedModelBehavior(
10741065
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
10751066
)
10761067

1077-
async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
1068+
async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
10781069
"""Build the initial messages for the conversation."""
10791070
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
10801071
for sys_prompt_runner in self._system_prompt_functions:
1081-
prompt = await sys_prompt_runner.run(deps)
1072+
prompt = await sys_prompt_runner.run(run_context)
10821073
messages.append(_messages.SystemPromptPart(prompt))
10831074
return messages
10841075

1085-
def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
1086-
self._incr_result_retry()
1076+
def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
1077+
self._incr_result_retry(run_context)
10871078
names = list(self._function_tools.keys())
10881079
if self._result_schema:
10891080
names.extend(self._result_schema.tool_names())

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from . import _result, _utils, exceptions, messages as _messages, models
1212
from .settings import UsageLimits
13-
from .tools import AgentDeps
13+
from .tools import AgentDeps, RunContext
1414

1515
__all__ = (
1616
'ResultData',
@@ -124,7 +124,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
124124
_usage_limits: UsageLimits | None
125125
_stream_response: models.EitherStreamedResponse
126126
_result_schema: _result.ResultSchema[ResultData] | None
127-
_deps: AgentDeps
127+
_run_ctx: RunContext[AgentDeps]
128128
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
129129
_result_tool_name: str | None
130130
_on_complete: Callable[[], Awaitable[None]]
@@ -311,17 +311,15 @@ async def validate_structured_result(
311311
result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)
312312

313313
for validator in self._result_validators:
314-
result_data = await validator.validate(result_data, self._deps, 0, call, self._all_messages)
314+
result_data = await validator.validate(result_data, call, self._run_ctx)
315315
return result_data
316316

317317
async def _validate_text_result(self, text: str) -> str:
318318
for validator in self._result_validators:
319319
text = await validator.validate( # pyright: ignore[reportAssignmentType]
320320
text, # pyright: ignore[reportArgumentType]
321-
self._deps,
322-
0,
323321
None,
324-
self._all_messages,
322+
self._run_ctx,
325323
)
326324
return text
327325

0 commit comments

Comments
 (0)