Skip to content

Commit ffd5691

Browse files
authored
Minor clean up in preparation of graph agent (#779)
1 parent 37651cb commit ffd5691

File tree

3 files changed

+71
-51
lines changed

3 files changed

+71
-51
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
- `'early'`: Stop processing other tool calls once a final result is found
6161
- `'exhaustive'`: Process all tool calls even after finding a final result
6262
"""
63-
RunResultData = TypeVar('RunResultData')
63+
RunResultDataT = TypeVar('RunResultDataT')
6464
"""Type variable for the result data of a run where `result_type` was customized on the run call."""
6565

6666

@@ -214,15 +214,15 @@ async def run(
214214
self,
215215
user_prompt: str,
216216
*,
217-
result_type: type[RunResultData],
217+
result_type: type[RunResultDataT],
218218
message_history: list[_messages.ModelMessage] | None = None,
219219
model: models.Model | models.KnownModelName | None = None,
220220
deps: AgentDepsT = None,
221221
model_settings: ModelSettings | None = None,
222222
usage_limits: _usage.UsageLimits | None = None,
223223
usage: _usage.Usage | None = None,
224224
infer_name: bool = True,
225-
) -> result.RunResult[RunResultData]: ...
225+
) -> result.RunResult[RunResultDataT]: ...
226226

227227
async def run(
228228
self,
@@ -234,7 +234,7 @@ async def run(
234234
model_settings: ModelSettings | None = None,
235235
usage_limits: _usage.UsageLimits | None = None,
236236
usage: _usage.Usage | None = None,
237-
result_type: type[RunResultData] | None = None,
237+
result_type: type[RunResultDataT] | None = None,
238238
infer_name: bool = True,
239239
) -> result.RunResult[Any]:
240240
"""Run the agent with a user prompt in async mode.
@@ -352,21 +352,21 @@ def run_sync(
352352
self,
353353
user_prompt: str,
354354
*,
355-
result_type: type[RunResultData] | None,
355+
result_type: type[RunResultDataT] | None,
356356
message_history: list[_messages.ModelMessage] | None = None,
357357
model: models.Model | models.KnownModelName | None = None,
358358
deps: AgentDepsT = None,
359359
model_settings: ModelSettings | None = None,
360360
usage_limits: _usage.UsageLimits | None = None,
361361
usage: _usage.Usage | None = None,
362362
infer_name: bool = True,
363-
) -> result.RunResult[RunResultData]: ...
363+
) -> result.RunResult[RunResultDataT]: ...
364364

365365
def run_sync(
366366
self,
367367
user_prompt: str,
368368
*,
369-
result_type: type[RunResultData] | None = None,
369+
result_type: type[RunResultDataT] | None = None,
370370
message_history: list[_messages.ModelMessage] | None = None,
371371
model: models.Model | models.KnownModelName | None = None,
372372
deps: AgentDepsT = None,
@@ -442,22 +442,22 @@ def run_stream(
442442
self,
443443
user_prompt: str,
444444
*,
445-
result_type: type[RunResultData],
445+
result_type: type[RunResultDataT],
446446
message_history: list[_messages.ModelMessage] | None = None,
447447
model: models.Model | models.KnownModelName | None = None,
448448
deps: AgentDepsT = None,
449449
model_settings: ModelSettings | None = None,
450450
usage_limits: _usage.UsageLimits | None = None,
451451
usage: _usage.Usage | None = None,
452452
infer_name: bool = True,
453-
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultData]]: ...
453+
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunResultDataT]]: ...
454454

455455
@asynccontextmanager
456456
async def run_stream(
457457
self,
458458
user_prompt: str,
459459
*,
460-
result_type: type[RunResultData] | None = None,
460+
result_type: type[RunResultDataT] | None = None,
461461
message_history: list[_messages.ModelMessage] | None = None,
462462
model: models.Model | models.KnownModelName | None = None,
463463
deps: AgentDepsT = None,
@@ -572,7 +572,7 @@ async def on_complete():
572572
# there are result validators that might convert the result data from an overridden
573573
# `result_type` to a type that is not valid as such.
574574
result_validators = cast(
575-
list[_result.ResultValidator[AgentDepsT, RunResultData]], self._result_validators
575+
list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
576576
)
577577

578578
yield result.StreamedRunResult(
@@ -999,7 +999,7 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
999999
return model_
10001000

10011001
async def _prepare_model(
1002-
self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
1002+
self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
10031003
) -> models.AgentModel:
10041004
"""Build tools and create an agent model."""
10051005
function_tools: list[ToolDefinition] = []
@@ -1035,8 +1035,8 @@ async def _reevaluate_dynamic_prompts(
10351035
)
10361036

10371037
def _prepare_result_schema(
1038-
self, result_type: type[RunResultData] | None
1039-
) -> _result.ResultSchema[RunResultData] | None:
1038+
self, result_type: type[RunResultDataT] | None
1039+
) -> _result.ResultSchema[RunResultDataT] | None:
10401040
if result_type is not None:
10411041
if self._result_validators:
10421042
raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
@@ -1053,7 +1053,7 @@ async def _prepare_messages(
10531053
run_context: RunContext[AgentDepsT],
10541054
) -> list[_messages.ModelMessage]:
10551055
try:
1056-
ctx_messages = _messages_ctx_var.get()
1056+
ctx_messages = get_captured_run_messages()
10571057
except LookupError:
10581058
messages: list[_messages.ModelMessage] = []
10591059
else:
@@ -1080,8 +1080,8 @@ async def _handle_model_response(
10801080
self,
10811081
model_response: _messages.ModelResponse,
10821082
run_context: RunContext[AgentDepsT],
1083-
result_schema: _result.ResultSchema[RunResultData] | None,
1084-
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1083+
result_schema: _result.ResultSchema[RunResultDataT] | None,
1084+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
10851085
"""Process a non-streamed response from the model.
10861086
10871087
Returns:
@@ -1110,11 +1110,11 @@ async def _handle_model_response(
11101110
raise exceptions.UnexpectedModelBehavior('Received empty model response')
11111111

11121112
async def _handle_text_response(
1113-
self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultData] | None
1114-
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1113+
self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
1114+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
11151115
"""Handle a plain text response from the model for non-streaming responses."""
11161116
if self._allow_text_result(result_schema):
1117-
result_data_input = cast(RunResultData, text)
1117+
result_data_input = cast(RunResultDataT, text)
11181118
try:
11191119
result_data = await self._validate_result(result_data_input, run_context, None)
11201120
except _result.ToolRetryError as e:
@@ -1133,13 +1133,13 @@ async def _handle_structured_response(
11331133
self,
11341134
tool_calls: list[_messages.ToolCallPart],
11351135
run_context: RunContext[AgentDepsT],
1136-
result_schema: _result.ResultSchema[RunResultData] | None,
1137-
) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
1136+
result_schema: _result.ResultSchema[RunResultDataT] | None,
1137+
) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
11381138
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
11391139
assert tool_calls, 'Expected at least one tool call'
11401140

11411141
# first look for the result tool call
1142-
final_result: _MarkFinalResult[RunResultData] | None = None
1142+
final_result: _MarkFinalResult[RunResultDataT] | None = None
11431143

11441144
parts: list[_messages.ModelRequestPart] = []
11451145
if result_schema is not None:
@@ -1168,7 +1168,7 @@ async def _process_function_tools(
11681168
tool_calls: list[_messages.ToolCallPart],
11691169
result_tool_name: str | None,
11701170
run_context: RunContext[AgentDepsT],
1171-
result_schema: _result.ResultSchema[RunResultData] | None,
1171+
result_schema: _result.ResultSchema[RunResultDataT] | None,
11721172
) -> list[_messages.ModelRequestPart]:
11731173
"""Process function (non-result) tool calls in parallel.
11741174
@@ -1227,7 +1227,7 @@ async def _handle_streamed_response(
12271227
self,
12281228
streamed_response: models.StreamedResponse,
12291229
run_context: RunContext[AgentDepsT],
1230-
result_schema: _result.ResultSchema[RunResultData] | None,
1230+
result_schema: _result.ResultSchema[RunResultDataT] | None,
12311231
) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
12321232
"""Process a streamed response from the model.
12331233
@@ -1282,15 +1282,15 @@ async def _handle_streamed_response(
12821282

12831283
async def _validate_result(
12841284
self,
1285-
result_data: RunResultData,
1285+
result_data: RunResultDataT,
12861286
run_context: RunContext[AgentDepsT],
12871287
tool_call: _messages.ToolCallPart | None,
1288-
) -> RunResultData:
1288+
) -> RunResultDataT:
12891289
if self._result_validators:
12901290
agent_result_data = cast(ResultDataT, result_data)
12911291
for validator in self._result_validators:
12921292
agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
1293-
return cast(RunResultData, agent_result_data)
1293+
return cast(RunResultDataT, agent_result_data)
12941294
else:
12951295
return result_data
12961296

@@ -1315,7 +1315,7 @@ async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_message
13151315
def _unknown_tool(
13161316
self,
13171317
tool_name: str,
1318-
result_schema: _result.ResultSchema[RunResultData] | None,
1318+
result_schema: _result.ResultSchema[RunResultDataT] | None,
13191319
) -> _messages.RetryPromptPart:
13201320
names = list(self._function_tools.keys())
13211321
if result_schema:
@@ -1358,7 +1358,7 @@ def _infer_name(self, function_frame: FrameType | None) -> None:
13581358
return
13591359

13601360
@staticmethod
1361-
def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
1361+
def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
13621362
return result_schema is None or result_schema.allow_text_result
13631363

13641364
@property
@@ -1413,6 +1413,10 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
14131413
_messages_ctx_var.reset(token)
14141414

14151415

1416+
def get_captured_run_messages() -> _RunMessages:
1417+
return _messages_ctx_var.get()
1418+
1419+
14161420
@dataclasses.dataclass
14171421
class _MarkFinalResult(Generic[ResultDataT]):
14181422
"""Marker class to indicate that the result is the final result.

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ class Tool(Generic[AgentDepsT]):
158158
_var_positional_field: str | None = field(init=False)
159159
_validator: SchemaValidator = field(init=False, repr=False)
160160
_parameters_json_schema: ObjectJsonSchema = field(init=False)
161+
162+
# TODO: Move this state off the Tool class, which is otherwise stateless.
163+
# This should be tracked inside a specific agent run, not the tool.
161164
current_retry: int = field(default=0, init=False)
162165

163166
def __init__(
@@ -261,7 +264,7 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition
261264

262265
async def run(
263266
self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT]
264-
) -> _messages.ModelRequestPart:
267+
) -> _messages.ToolReturnPart | _messages.RetryPromptPart:
265268
"""Run the tool function asynchronously."""
266269
try:
267270
if isinstance(message.args, str):

pydantic_graph/pydantic_graph/graph.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import types
66
from collections.abc import Sequence
7+
from contextlib import ExitStack
78
from dataclasses import dataclass, field
89
from functools import cached_property
910
from pathlib import Path
@@ -75,6 +76,7 @@ async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
7576
snapshot_state: Callable[[StateT], StateT]
7677
_state_type: type[StateT] | _utils.Unset = field(repr=False)
7778
_run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
79+
_auto_instrument: bool = field(repr=False)
7880

7981
def __init__(
8082
self,
@@ -84,6 +86,7 @@ def __init__(
8486
state_type: type[StateT] | _utils.Unset = _utils.UNSET,
8587
run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
8688
snapshot_state: Callable[[StateT], StateT] = deep_copy_state,
89+
auto_instrument: bool = True,
8790
):
8891
"""Create a graph from a sequence of nodes.
8992
@@ -97,10 +100,12 @@ def __init__(
97100
snapshot_state: A function to snapshot the state of the graph, this is used in
98101
[`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record
99102
the state before each step.
103+
auto_instrument: Whether to create a span for the graph run and the execution of each node's run method.
100104
"""
101105
self.name = name
102106
self._state_type = state_type
103107
self._run_end_type = run_end_type
108+
self._auto_instrument = auto_instrument
104109
self.snapshot_state = snapshot_state
105110

106111
parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
@@ -155,26 +160,32 @@ async def main():
155160
self._infer_name(inspect.currentframe())
156161

157162
history: list[HistoryStep[StateT, T]] = []
158-
with _logfire.span(
159-
'{graph_name} run {start=}',
160-
graph_name=self.name or 'graph',
161-
start=start_node,
162-
) as run_span:
163-
while True:
164-
next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False)
165-
if isinstance(next_node, End):
166-
history.append(EndStep(result=next_node))
163+
with ExitStack() as stack:
164+
run_span: logfire_api.LogfireSpan | None = None
165+
if self._auto_instrument:
166+
run_span = stack.enter_context(
167+
_logfire.span(
168+
'{graph_name} run {start=}',
169+
graph_name=self.name or 'graph',
170+
start=start_node,
171+
)
172+
)
173+
while True:
174+
next_node = await self.next(start_node, history, state=state, deps=deps, infer_name=False)
175+
if isinstance(next_node, End):
176+
history.append(EndStep(result=next_node))
177+
if run_span is not None:
167178
run_span.set_attribute('history', history)
168-
return next_node.data, history
169-
elif isinstance(next_node, BaseNode):
170-
start_node = next_node
179+
return next_node.data, history
180+
elif isinstance(next_node, BaseNode):
181+
start_node = next_node
182+
else:
183+
if TYPE_CHECKING:
184+
typing_extensions.assert_never(next_node)
171185
else:
172-
if TYPE_CHECKING:
173-
typing_extensions.assert_never(next_node)
174-
else:
175-
raise exceptions.GraphRuntimeError(
176-
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
177-
)
186+
raise exceptions.GraphRuntimeError(
187+
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
188+
)
178189

179190
def run_sync(
180191
self: Graph[StateT, DepsT, T],
@@ -232,8 +243,10 @@ async def next(
232243
if node_id not in self.node_defs:
233244
raise exceptions.GraphRuntimeError(f'Node `{node}` is not in the graph.')
234245

235-
ctx = GraphRunContext(state, deps)
236-
with _logfire.span('run node {node_id}', node_id=node_id, node=node):
246+
with ExitStack() as stack:
247+
if self._auto_instrument:
248+
stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))
249+
ctx = GraphRunContext(state, deps)
237250
start_ts = _utils.now_utc()
238251
start = perf_counter()
239252
next_node = await node.run(ctx)

0 commit comments

Comments
 (0)