Skip to content

Commit 8f74221

Browse files
authored
add messages to RunContext (#257)
1 parent ccd26a1 commit 8f74221

File tree

7 files changed

+127
-69
lines changed

7 files changed

+127
-69
lines changed

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
from pydantic import TypeAdapter, ValidationError
1111
from typing_extensions import Self, TypeAliasType, TypedDict
1212

13-
from . import _utils, messages
13+
from . import _utils, messages as _messages
1414
from .exceptions import ModelRetry
15-
from .messages import ModelResponse, ToolCallPart
1615
from .result import ResultData
1716
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
1817

@@ -28,7 +27,12 @@ def __post_init__(self):
2827
self._is_async = inspect.iscoroutinefunction(self.function)
2928

3029
async def validate(
31-
self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCallPart | None
30+
self,
31+
result: ResultData,
32+
deps: AgentDeps,
33+
retry: int,
34+
tool_call: _messages.ToolCallPart | None,
35+
messages: list[_messages.Message],
3236
) -> ResultData:
3337
"""Validate a result but calling the function.
3438
@@ -37,12 +41,13 @@ async def validate(
3741
deps: The agent dependencies.
3842
retry: The current retry number.
3943
tool_call: The original tool call message, `None` if there was no tool call.
44+
messages: The messages exchanged so far in the conversation.
4045
4146
Returns:
4247
Result of either the validated result data (ok) or a retry message (Err).
4348
"""
4449
if self._takes_ctx:
45-
args = RunContext(deps, retry, tool_call.tool_name if tool_call else None), result
50+
args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
4651
else:
4752
args = (result,)
4853

@@ -54,7 +59,7 @@ async def validate(
5459
function = cast(Callable[[Any], ResultData], self.function)
5560
result_data = await _utils.run_in_executor(function, *args)
5661
except ModelRetry as r:
57-
m = messages.RetryPrompt(content=r.message)
62+
m = _messages.RetryPrompt(content=r.message)
5863
if tool_call is not None:
5964
m.tool_name = tool_call.tool_name
6065
m.tool_call_id = tool_call.tool_call_id
@@ -66,7 +71,7 @@ async def validate(
6671
class ToolRetryError(Exception):
6772
"""Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
6873

69-
def __init__(self, tool_retry: messages.RetryPrompt):
74+
def __init__(self, tool_retry: _messages.RetryPrompt):
7075
self.tool_retry = tool_retry
7176
super().__init__()
7277

@@ -108,10 +113,12 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
108113

109114
return cls(tools=tools, allow_text_result=allow_text_result)
110115

111-
def find_tool(self, message: ModelResponse) -> tuple[ToolCallPart, ResultTool[ResultData]] | None:
116+
def find_tool(
117+
self, message: _messages.ModelResponse
118+
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
112119
"""Find a tool that matches one of the calls."""
113120
for item in message.parts:
114-
if isinstance(item, ToolCallPart):
121+
if isinstance(item, _messages.ToolCallPart):
115122
if result := self.tools.get(item.tool_name):
116123
return item, result
117124

@@ -168,7 +175,7 @@ def __init__(self, response_type: type[ResultData], name: str, description: str
168175
)
169176

170177
def validate(
171-
self, tool_call: messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
178+
self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
172179
) -> ResultData:
173180
"""Validate a result message.
174181
@@ -182,7 +189,7 @@ def validate(
182189
"""
183190
try:
184191
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
185-
if isinstance(tool_call.args, messages.ArgsJson):
192+
if isinstance(tool_call.args, _messages.ArgsJson):
186193
result = self.type_adapter.validate_json(
187194
tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
188195
)
@@ -192,7 +199,7 @@ def validate(
192199
)
193200
except ValidationError as e:
194201
if wrap_validation_errors:
195-
m = messages.RetryPrompt(
202+
m = _messages.RetryPrompt(
196203
tool_name=tool_call.tool_name,
197204
content=e.errors(include_url=False),
198205
tool_call_id=tool_call.tool_call_id,

pydantic_ai_slim/pydantic_ai/_system_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __post_init__(self):
2121

2222
async def run(self, deps: AgentDeps) -> str:
2323
if self._takes_ctx:
24-
args = (RunContext(deps, 0),)
24+
args = (RunContext(deps, 0, [], None),)
2525
else:
2626
args = ()
2727

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ async def run(
243243
while True:
244244
run_step += 1
245245
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
246-
agent_model = await self._prepare_model(model_used, deps)
246+
agent_model = await self._prepare_model(model_used, deps, messages)
247247

248248
with _logfire.span('model request', run_step=run_step) as model_req_span:
249249
model_response, request_cost = await agent_model.request(messages, model_settings)
@@ -255,7 +255,7 @@ async def run(
255255
cost += request_cost
256256

257257
with _logfire.span('handle model response', run_step=run_step) as handle_span:
258-
final_result, response_messages = await self._handle_model_response(model_response, deps)
258+
final_result, response_messages = await self._handle_model_response(model_response, deps, messages)
259259

260260
# Add all messages to the conversation
261261
messages.extend(response_messages)
@@ -391,7 +391,7 @@ async def main():
391391
run_step += 1
392392

393393
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
394-
agent_model = await self._prepare_model(model_used, deps)
394+
agent_model = await self._prepare_model(model_used, deps, messages)
395395

396396
with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
397397
async with agent_model.request_stream(messages, model_settings) as model_response:
@@ -402,7 +402,7 @@ async def main():
402402

403403
with _logfire.span('handle model response') as handle_span:
404404
final_result, response_messages = await self._handle_streamed_model_response(
405-
model_response, deps
405+
model_response, deps, messages
406406
)
407407

408408
# Add all messages to the conversation
@@ -773,12 +773,14 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
773773

774774
return model_, mode_selection
775775

776-
async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
776+
async def _prepare_model(
777+
self, model: models.Model, deps: AgentDeps, messages: list[_messages.Message]
778+
) -> models.AgentModel:
777779
"""Create building tools and create an agent model."""
778780
function_tools: list[ToolDefinition] = []
779781

780782
async def add_tool(tool: Tool[AgentDeps]) -> None:
781-
ctx = RunContext(deps, tool.current_retry, tool.name)
783+
ctx = RunContext(deps, tool.current_retry, messages, tool.name)
782784
if tool_def := await tool.prepare_tool_def(ctx):
783785
function_tools.append(tool_def)
784786

@@ -807,7 +809,7 @@ async def _prepare_messages(
807809
return new_message_index, messages
808810

809811
async def _handle_model_response(
810-
self, model_response: _messages.ModelResponse, deps: AgentDeps
812+
self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.Message]
811813
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
812814
"""Process a non-streamed response from the model.
813815
@@ -824,20 +826,20 @@ async def _handle_model_response(
824826

825827
if texts:
826828
text = '\n\n'.join(texts)
827-
return await self._handle_text_response(text, deps)
829+
return await self._handle_text_response(text, deps, conv_messages)
828830
elif tool_calls:
829-
return await self._handle_structured_response(tool_calls, deps)
831+
return await self._handle_structured_response(tool_calls, deps, conv_messages)
830832
else:
831833
raise exceptions.UnexpectedModelBehavior('Received empty model response')
832834

833835
async def _handle_text_response(
834-
self, text: str, deps: AgentDeps
836+
self, text: str, deps: AgentDeps, conv_messages: list[_messages.Message]
835837
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
836838
"""Handle a plain text response from the model for non-streaming responses."""
837839
if self._allow_text_result:
838840
result_data_input = cast(ResultData, text)
839841
try:
840-
result_data = await self._validate_result(result_data_input, deps, None)
842+
result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
841843
except _result.ToolRetryError as e:
842844
self._incr_result_retry()
843845
return None, [e.tool_retry]
@@ -851,26 +853,24 @@ async def _handle_text_response(
851853
return None, [response]
852854

853855
async def _handle_structured_response(
854-
self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps
856+
self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.Message]
855857
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
856858
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
857859
assert tool_calls, 'Expected at least one tool call'
858860

859861
# First process any final result tool calls
860-
final_result, final_messages = await self._process_final_tool_calls(tool_calls, deps)
862+
final_result, final_messages = await self._process_final_tool_calls(tool_calls, deps, conv_messages)
861863

862864
# Then process regular tools based on end strategy
863865
if self.end_strategy == 'early' and final_result:
864866
tool_messages = self._mark_skipped_function_tools(tool_calls)
865867
else:
866-
tool_messages = await self._process_function_tools(tool_calls, deps)
868+
tool_messages = await self._process_function_tools(tool_calls, deps, conv_messages)
867869

868870
return final_result, [*final_messages, *tool_messages]
869871

870872
async def _process_final_tool_calls(
871-
self,
872-
tool_calls: list[_messages.ToolCallPart],
873-
deps: AgentDeps,
873+
self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.Message]
874874
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
875875
"""Process any final result tool calls and return the first valid result."""
876876
if not self._result_schema:
@@ -888,7 +888,7 @@ async def _process_final_tool_calls(
888888
# This is the first result tool - try to use it
889889
try:
890890
result_data = result_tool.validate(call)
891-
result_data = await self._validate_result(result_data, deps, call)
891+
result_data = await self._validate_result(result_data, deps, call, conv_messages)
892892
except _result.ToolRetryError as e:
893893
self._incr_result_retry()
894894
messages.append(e.tool_retry)
@@ -914,17 +914,15 @@ async def _process_final_tool_calls(
914914
return final_result, messages
915915

916916
async def _process_function_tools(
917-
self,
918-
tool_calls: list[_messages.ToolCallPart],
919-
deps: AgentDeps,
917+
self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.Message]
920918
) -> list[_messages.Message]:
921919
"""Process function (non-final) tool calls in parallel."""
922920
messages: list[_messages.Message] = []
923921
tasks: list[asyncio.Task[_messages.Message]] = []
924922

925923
for call in tool_calls:
926924
if tool := self._function_tools.get(call.tool_name):
927-
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
925+
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
928926
elif self._result_schema is None or call.tool_name not in self._result_schema.tools:
929927
messages.append(self._unknown_tool(call.tool_name))
930928

@@ -958,7 +956,7 @@ def _mark_skipped_function_tools(
958956
return messages
959957

960958
async def _handle_streamed_model_response(
961-
self, model_response: models.EitherStreamedResponse, deps: AgentDeps
959+
self, model_response: models.EitherStreamedResponse, deps: AgentDeps, conv_messages: list[_messages.Message]
962960
) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
963961
"""Process a streamed response from the model.
964962
@@ -1015,7 +1013,7 @@ async def _handle_streamed_model_response(
10151013
if isinstance(item, _messages.ToolCallPart):
10161014
call = item
10171015
if tool := self._function_tools.get(call.tool_name):
1018-
tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
1016+
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
10191017
else:
10201018
messages.append(self._unknown_tool(call.tool_name))
10211019

@@ -1025,10 +1023,16 @@ async def _handle_streamed_model_response(
10251023
return None, messages
10261024

10271025
async def _validate_result(
1028-
self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCallPart | None
1026+
self,
1027+
result_data: ResultData,
1028+
deps: AgentDeps,
1029+
tool_call: _messages.ToolCallPart | None,
1030+
conv_messages: list[_messages.Message],
10291031
) -> ResultData:
10301032
for validator in self._result_validators:
1031-
result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call)
1033+
result_data = await validator.validate(
1034+
result_data, deps, self._current_result_retry, tool_call, conv_messages
1035+
)
10321036
return result_data
10331037

10341038
def _incr_result_retry(self) -> None:

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,13 @@ class ModelResponse:
220220
message_kind: Literal['model-response'] = 'model-response'
221221
"""Message source identifier, this type is available on all messages as a discriminator."""
222222

223-
@staticmethod
224-
def from_text(content: str, timestamp: datetime | None = None) -> ModelResponse:
225-
return ModelResponse([TextPart(content)], timestamp=timestamp or _now_utc())
223+
@classmethod
224+
def from_text(cls, content: str, timestamp: datetime | None = None) -> ModelResponse:
225+
return cls([TextPart(content)], timestamp=timestamp or _now_utc())
226+
227+
@classmethod
228+
def from_tool_call(cls, tool_call: ToolCallPart) -> ModelResponse:
229+
return cls([tool_call])
226230

227231

228232
Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelResponse]

0 commit comments

Comments
 (0)