Skip to content

Commit f897c5c

Browse files
authored
Add with agent.sequential_tool_calls(): contextmanager and use it in DBOSAgent (#2856)
1 parent d78a3db commit f897c5c

File tree

6 files changed

+52
-15
lines changed

6 files changed

+52
-15
lines changed

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

33
import json
4+
from collections.abc import Iterator
5+
from contextlib import contextmanager
6+
from contextvars import ContextVar
47
from dataclasses import dataclass, field, replace
58
from typing import Any, Generic
69

@@ -16,6 +19,8 @@
1619
from .toolsets.abstract import AbstractToolset, ToolsetTool
1720
from .usage import UsageLimits
1821

22+
_sequential_tool_calls_ctx_var: ContextVar[bool] = ContextVar('sequential_tool_calls', default=False)
23+
1924

2025
@dataclass
2126
class ToolManager(Generic[AgentDepsT]):
@@ -30,6 +35,16 @@ class ToolManager(Generic[AgentDepsT]):
3035
failed_tools: set[str] = field(default_factory=set)
3136
"""Names of tools that failed in this run step."""
3237

38+
@classmethod
39+
@contextmanager
40+
def sequential_tool_calls(cls) -> Iterator[None]:
41+
"""Run tool calls sequentially during the context."""
42+
token = _sequential_tool_calls_ctx_var.set(True)
43+
try:
44+
yield
45+
finally:
46+
_sequential_tool_calls_ctx_var.reset(token)
47+
3348
async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
3449
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
3550
if self.ctx is not None:
@@ -58,7 +73,9 @@ def tool_defs(self) -> list[ToolDefinition]:
5873

5974
def should_call_sequentially(self, calls: list[ToolCallPart]) -> bool:
6075
"""Whether to require sequential tool calls for a list of tool calls."""
61-
return any(tool_def.sequential for call in calls if (tool_def := self.get_tool_def(call.tool_name)))
76+
return _sequential_tool_calls_ctx_var.get() or any(
77+
tool_def.sequential for call in calls if (tool_def := self.get_tool_def(call.tool_name))
78+
)
6279

6380
def get_tool_def(self, name: str) -> ToolDefinition | None:
6481
"""Get the tool definition for a given tool name, or `None` if the tool is unknown."""

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
result,
2222
usage as _usage,
2323
)
24+
from .._tool_manager import ToolManager
2425
from ..output import OutputDataT, OutputSpec
2526
from ..result import AgentStream, FinalResult, StreamedRunResult
2627
from ..run import AgentRun, AgentRunResult
@@ -714,6 +715,13 @@ def _infer_name(self, function_frame: FrameType | None) -> None:
714715
self.name = name
715716
return
716717

718+
@staticmethod
719+
@contextmanager
720+
def sequential_tool_calls() -> Iterator[None]:
721+
"""Run tool calls sequentially during the context."""
722+
with ToolManager.sequential_tool_calls():
723+
yield
724+
717725
@staticmethod
718726
def is_model_request_node(
719727
node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]],

pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
models,
1414
usage as _usage,
1515
)
16-
from pydantic_ai._run_context import AgentDepsT
1716
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
1817
from pydantic_ai.exceptions import UserError
1918
from pydantic_ai.mcp import MCPServer
@@ -22,6 +21,7 @@
2221
from pydantic_ai.result import StreamedRunResult
2322
from pydantic_ai.settings import ModelSettings
2423
from pydantic_ai.tools import (
24+
AgentDepsT,
2525
DeferredToolResults,
2626
RunContext,
2727
Tool,
@@ -218,7 +218,10 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
218218
@contextmanager
219219
def _dbos_overrides(self) -> Iterator[None]:
220220
# Override with DBOSModel and DBOSMCPServer in the toolsets.
221-
with super().override(model=self._model, toolsets=self._toolsets, tools=[]):
221+
with (
222+
super().override(model=self._model, toolsets=self._toolsets, tools=[]),
223+
self.sequential_tool_calls(),
224+
):
222225
yield
223226

224227
@overload

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
models,
2222
usage as _usage,
2323
)
24-
from pydantic_ai._run_context import AgentDepsT
2524
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
2625
from pydantic_ai.exceptions import UserError
2726
from pydantic_ai.models import Model
2827
from pydantic_ai.output import OutputDataT, OutputSpec
2928
from pydantic_ai.result import StreamedRunResult
3029
from pydantic_ai.settings import ModelSettings
3130
from pydantic_ai.tools import (
31+
AgentDepsT,
3232
DeferredToolResults,
3333
RunContext,
3434
Tool,

tests/test_agent.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterable, Callable
66
from dataclasses import dataclass, replace
77
from datetime import timezone
8-
from typing import Any, Union
8+
from typing import Any, Literal, Union
99

1010
import httpx
1111
import pytest
@@ -4327,7 +4327,8 @@ async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) ->
43274327
assert result.output == snapshot('finished')
43284328

43294329

4330-
def test_sequential_calls():
4330+
@pytest.mark.parametrize('mode', ['argument', 'contextmanager'])
4331+
def test_sequential_calls(mode: Literal['argument', 'contextmanager']):
43314332
"""Test that tool calls are executed correctly when a `sequential` tool is present in the call."""
43324333

43334334
async def call_tools_sequential(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
@@ -4355,31 +4356,37 @@ async def call_tools_sequential(messages: list[ModelMessage], info: AgentInfo) -
43554356

43564357
integer_holder: int = 1
43574358

4358-
@sequential_toolset.tool(sequential=False)
4359+
@sequential_toolset.tool
43594360
def call_first():
43604361
nonlocal integer_holder
43614362
assert integer_holder == 1
43624363

4363-
@sequential_toolset.tool(sequential=True)
4364+
@sequential_toolset.tool(sequential=mode == 'argument')
43644365
def increment_integer_holder():
43654366
nonlocal integer_holder
43664367
integer_holder = 2
43674368

4368-
@sequential_toolset.tool()
4369+
@sequential_toolset.tool
43694370
def requires_approval():
43704371
from pydantic_ai.exceptions import ApprovalRequired
43714372

43724373
raise ApprovalRequired()
43734374

4374-
@sequential_toolset.tool(sequential=False)
4375+
@sequential_toolset.tool
43754376
def call_second():
43764377
nonlocal integer_holder
43774378
assert integer_holder == 2
43784379

43794380
agent = Agent(
43804381
FunctionModel(call_tools_sequential), toolsets=[sequential_toolset], output_type=[str, DeferredToolRequests]
43814382
)
4382-
result = agent.run_sync()
4383+
4384+
if mode == 'contextmanager':
4385+
with agent.sequential_tool_calls():
4386+
result = agent.run_sync()
4387+
else:
4388+
result = agent.run_sync()
4389+
43834390
assert result.output == snapshot(
43844391
DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())])
43854392
)

tests/test_dbos.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,11 @@ def dbos() -> Generator[DBOS, Any, None]:
132132
def cleanup_test_sqlite_file() -> Iterator[None]:
133133
if os.path.exists(DBOS_SQLITE_FILE):
134134
os.remove(DBOS_SQLITE_FILE) # pragma: lax no cover
135-
yield
136-
137-
if os.path.exists(DBOS_SQLITE_FILE):
138-
os.remove(DBOS_SQLITE_FILE) # pragma: lax no cover
135+
try:
136+
yield
137+
finally:
138+
if os.path.exists(DBOS_SQLITE_FILE):
139+
os.remove(DBOS_SQLITE_FILE) # pragma: lax no cover
139140

140141

141142
model = OpenAIChatModel(
@@ -256,6 +257,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
256257
'complex_agent__model.request_stream',
257258
'event_stream_handler',
258259
'event_stream_handler',
260+
'event_stream_handler',
259261
'complex_agent__mcp_server__mcp.call_tool',
260262
'event_stream_handler',
261263
'complex_agent__mcp_server__mcp.get_tools',

0 commit comments

Comments
 (0)