|
5 | 5 | from collections.abc import AsyncIterable, Callable
|
6 | 6 | from dataclasses import dataclass, replace
|
7 | 7 | from datetime import timezone
|
8 |
| -from typing import Any, Union |
| 8 | +from typing import Any, Literal, Union |
9 | 9 |
|
10 | 10 | import httpx
|
11 | 11 | import pytest
|
@@ -4327,7 +4327,8 @@ async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) ->
|
4327 | 4327 | assert result.output == snapshot('finished')
|
4328 | 4328 |
|
4329 | 4329 |
|
4330 |
| -def test_sequential_calls(): |
| 4330 | +@pytest.mark.parametrize('mode', ['argument', 'contextmanager']) |
| 4331 | +def test_sequential_calls(mode: Literal['argument', 'contextmanager']): |
4331 | 4332 | """Test that tool calls are executed correctly when a `sequential` tool is present in the call."""
|
4332 | 4333 |
|
4333 | 4334 | async def call_tools_sequential(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
|
@@ -4355,31 +4356,37 @@ async def call_tools_sequential(messages: list[ModelMessage], info: AgentInfo) -
|
4355 | 4356 |
|
4356 | 4357 | integer_holder: int = 1
|
4357 | 4358 |
|
4358 |
| - @sequential_toolset.tool(sequential=False) |
| 4359 | + @sequential_toolset.tool |
4359 | 4360 | def call_first():
|
4360 | 4361 | nonlocal integer_holder
|
4361 | 4362 | assert integer_holder == 1
|
4362 | 4363 |
|
4363 |
| - @sequential_toolset.tool(sequential=True) |
| 4364 | + @sequential_toolset.tool(sequential=mode == 'argument') |
4364 | 4365 | def increment_integer_holder():
|
4365 | 4366 | nonlocal integer_holder
|
4366 | 4367 | integer_holder = 2
|
4367 | 4368 |
|
4368 |
| - @sequential_toolset.tool() |
| 4369 | + @sequential_toolset.tool |
4369 | 4370 | def requires_approval():
|
4370 | 4371 | from pydantic_ai.exceptions import ApprovalRequired
|
4371 | 4372 |
|
4372 | 4373 | raise ApprovalRequired()
|
4373 | 4374 |
|
4374 |
| - @sequential_toolset.tool(sequential=False) |
| 4375 | + @sequential_toolset.tool |
4375 | 4376 | def call_second():
|
4376 | 4377 | nonlocal integer_holder
|
4377 | 4378 | assert integer_holder == 2
|
4378 | 4379 |
|
4379 | 4380 | agent = Agent(
|
4380 | 4381 | FunctionModel(call_tools_sequential), toolsets=[sequential_toolset], output_type=[str, DeferredToolRequests]
|
4381 | 4382 | )
|
4382 |
| - result = agent.run_sync() |
| 4383 | + |
| 4384 | + if mode == 'contextmanager': |
| 4385 | + with agent.sequential_tool_calls(): |
| 4386 | + result = agent.run_sync() |
| 4387 | + else: |
| 4388 | + result = agent.run_sync() |
| 4389 | + |
4383 | 4390 | assert result.output == snapshot(
|
4384 | 4391 | DeferredToolRequests(approvals=[ToolCallPart(tool_name='requires_approval', tool_call_id=IsStr())])
|
4385 | 4392 | )
|
|
0 commit comments