|
4 | 4 | import pickle |
5 | 5 | from collections.abc import Callable, Sequence |
6 | 6 | from enum import IntEnum, auto |
7 | | -from typing import Any |
| 7 | +from typing import Any, cast |
8 | 8 | from unittest.mock import patch |
9 | 9 |
|
10 | 10 | import pytest |
@@ -956,3 +956,54 @@ async def export_frame(self): |
956 | 956 | ) |
957 | 957 | assert response.status_code == 200 |
958 | 958 | assert response.json()["result"] == "3" |
| 959 | + |
| 960 | + |
@pytest.mark.asyncio
async def test_mixed_concurrency() -> None:
    """Check how safe and unsafe tool calls overlap within a single step.

    Every tool call reports the number of tool bodies that were running at
    the moment it started (itself included). Calls to the tool created with
    ``concurrency_safe=False`` must always report 1, and at least one call to
    the other tool must report more than 1.
    """
    # Number of tool bodies currently between "started" and "finished".
    running = 0
    running_lock = asyncio.Lock()

    async def sleep_fn() -> int:
        """Stub."""
        nonlocal running
        async with running_lock:
            running += 1
            # Snapshot taken on entry; this is the value handed back.
            seen_on_entry = running
        # Stay "running" long enough for other calls to overlap with us.
        await asyncio.sleep(0.5)
        async with running_lock:
            running -= 1
        return seen_on_entry

    async def unsafe_sleep_fn() -> int:
        """Stub."""
        seen_on_entry = await sleep_fn()
        return seen_on_entry

    safe_sleep = Tool.from_function(sleep_fn)
    unsafe_sleep = Tool.from_function(unsafe_sleep_fn, concurrency_safe=False)

    dummy_env = DummyEnv()
    await dummy_env.reset()
    dummy_env.tools = [safe_sleep, unsafe_sleep]

    # True -> call the safe tool, False -> call the unsafe tool.
    safes = [True, True, True, False, True, False, False, True, True]
    requested_calls = [
        ToolCall.from_tool(safe_sleep if is_safe else unsafe_sleep)
        for is_safe in safes
    ]
    obs, *_ = await dummy_env.step(ToolRequestMessage(tool_calls=requested_calls))

    saw_overlap = False
    for is_safe, message in zip(safes, obs, strict=True):
        started_with = int(cast(str, message.content))
        if not is_safe:
            assert started_with == 1, (
                "Expected unsafe tools to block all other tool calls."
            )
            continue
        if started_with > 1:
            saw_overlap = True

    assert saw_overlap, (
        "Expected at least one safe tool call to run concurrently with another."
    )
0 commit comments