Skip to content

Commit 05ea2b0

Browse files
authored
Partial tool concurrency (#305)
1 parent 3a074b9 commit 05ea2b0

File tree

4 files changed

+145
-20
lines changed

4 files changed

+145
-20
lines changed

src/aviary/env.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
ToolRequestMessage,
3232
ToolResponseMessage,
3333
)
34-
from aviary.utils import format_exc, is_coroutine_callable
34+
from aviary.utils import ReaderWriterLock, format_exc, is_coroutine_callable
3535

3636
logger = logging.getLogger(__name__)
3737

@@ -215,6 +215,7 @@ async def exec_tool_calls(
215215
Ordered list of ToolResponseMessages, order matches the order of tool calls
216216
in the input message.
217217
"""
218+
concurrency_lock = ReaderWriterLock()
218219

219220
async def _exec_tool_call(tool_call: ToolCall) -> ToolResponseMessage:
220221
start = time.monotonic()
@@ -227,6 +228,7 @@ async def _exec_tool_call(tool_call: ToolCall) -> ToolResponseMessage:
227228
f"{tool_call.function.name!r} not a valid name in"
228229
f" { {t.info.name for t in self.tools} }."
229230
) from exc
231+
230232
# we do a special convenience to make
231233
# state be optional in the function signature
232234
need_to_filter = (
@@ -239,25 +241,33 @@ async def _exec_tool_call(tool_call: ToolCall) -> ToolResponseMessage:
239241
if need_to_filter
240242
else function_kwargs
241243
)
244+
245+
concurrency_context = (
246+
concurrency_lock.read_lock()
247+
if tool.concurrency_safe
248+
else concurrency_lock.write_lock()
249+
)
250+
242251
tool_exc: Exception | None = None
243252
try:
244-
if is_coroutine_callable(tool._tool_fn):
245-
content = await maybe_wait_for(
246-
tool._tool_fn(
247-
**tool_call.function.arguments, **filtered_kwargs
248-
),
249-
exec_timeout,
250-
)
251-
else:
252-
# If the function is synchronous, run on a thread
253-
content = await maybe_wait_for(
254-
asyncio.to_thread(
255-
tool._tool_fn,
256-
**tool_call.function.arguments,
257-
**filtered_kwargs,
258-
),
259-
exec_timeout,
260-
)
253+
async with concurrency_context:
254+
if is_coroutine_callable(tool._tool_fn):
255+
content = await maybe_wait_for(
256+
tool._tool_fn(
257+
**tool_call.function.arguments, **filtered_kwargs
258+
),
259+
exec_timeout,
260+
)
261+
else:
262+
# If the function is synchronous, run on a thread
263+
content = await maybe_wait_for(
264+
asyncio.to_thread(
265+
tool._tool_fn,
266+
**tool_call.function.arguments,
267+
**filtered_kwargs,
268+
),
269+
exec_timeout,
270+
)
261271
except Exception as exc:
262272
if not handle_tool_exc:
263273
raise

src/aviary/tools/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,19 @@ class Tool(BaseModel):
344344
" serialization, and the validation alias enables deserialization."
345345
),
346346
)
347+
concurrency_safe: bool = Field(
348+
default=True,
349+
# Exclude since we need Tool.model_dump() to conform to OpenAI schema.
350+
# Note that this is safe: while we do (de)serialize tools when e.g. passing to
351+
# agents, only Environment.exec_tool_calls uses this field. And we never serialize
352+
# an env after reset.
353+
exclude=True,
354+
description=(
355+
"Whether the tool is safe to run concurrently with itself and other tools. "
356+
"If set to False (not default), then executing this tool will block all "
357+
"other tool calls (including concurrency-safe tools)."
358+
),
359+
)
347360

348361
def __init__(
349362
self,
@@ -382,6 +395,7 @@ def from_function(
382395
docstring_style: DocstringStyle = DocstringStyle.AUTO,
383396
allow_empty_param_descriptions: bool = False,
384397
types_in_param_descriptions: bool = False,
398+
concurrency_safe: bool = True,
385399
**formats,
386400
) -> "Tool":
387401
"""Hydrate this class via inspection from a free function with a docstring."""
@@ -450,6 +464,7 @@ def from_function(
450464
),
451465
parameters=json_schema,
452466
),
467+
concurrency_safe=concurrency_safe,
453468
)
454469

455470

src/aviary/utils.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import asyncio
12
import base64
23
import inspect
34
import io
45
import random
56
import string
67
from ast import literal_eval
78
from collections import UserDict
8-
from collections.abc import Sequence
9+
from collections.abc import AsyncIterator, Sequence
10+
from contextlib import asynccontextmanager
911
from enum import StrEnum
1012
from typing import (
1113
TYPE_CHECKING,
@@ -510,3 +512,50 @@ def format_exc(exc: BaseException) -> str:
510512
f" {', '.join(repr(e) for e in exc.exceptions)}"
511513
)
512514
return repr(exc)
515+
516+
517+
class ReaderWriterLock:
518+
"""An asyncio lock that allows execution of multiple readers or a single writer.
519+
520+
When a writer is executing, it will block all readers and writers.
521+
The main use case here is for concurrency-unsafe tools to block execution
522+
of other tool calls, while still allowing concurrency-safe tools to execute
523+
in parallel with each other.
524+
"""
525+
526+
def __init__(self):
527+
self._readers = 0
528+
self._writer = False
529+
self._lock = asyncio.Lock()
530+
self._write_ok = asyncio.Condition(self._lock)
531+
self._read_ok = asyncio.Condition(self._lock)
532+
533+
@asynccontextmanager
534+
async def read_lock(self) -> AsyncIterator[None]:
535+
"""Acquire a read lock. This blocks all writers."""
536+
async with self._lock:
537+
while self._writer:
538+
await self._read_ok.wait()
539+
self._readers += 1
540+
try:
541+
yield
542+
finally:
543+
async with self._lock:
544+
self._readers -= 1
545+
if self._readers == 0:
546+
self._write_ok.notify_all()
547+
548+
@asynccontextmanager
549+
async def write_lock(self) -> AsyncIterator[None]:
550+
"""Acquire a write lock. This blocks all readers and writers."""
551+
async with self._lock:
552+
while self._writer or self._readers > 0:
553+
await self._write_ok.wait()
554+
self._writer = True
555+
try:
556+
yield
557+
finally:
558+
async with self._lock:
559+
self._writer = False
560+
self._read_ok.notify_all()
561+
self._write_ok.notify_all()

tests/test_tools.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pickle
55
from collections.abc import Callable, Sequence
66
from enum import IntEnum, auto
7-
from typing import Any
7+
from typing import Any, cast
88
from unittest.mock import patch
99

1010
import pytest
@@ -956,3 +956,54 @@ async def export_frame(self):
956956
)
957957
assert response.status_code == 200
958958
assert response.json()["result"] == "3"
959+
960+
961+
@pytest.mark.asyncio
962+
async def test_mixed_concurrency() -> None:
963+
# Counts the number of tools executing concurrently
964+
counter = 0
965+
counter_lock = asyncio.Lock()
966+
967+
async def sleep_fn() -> int:
968+
"""Stub."""
969+
nonlocal counter
970+
async with counter_lock:
971+
counter += 1
972+
counter_val = counter
973+
await asyncio.sleep(0.5)
974+
async with counter_lock:
975+
counter -= 1
976+
return counter_val
977+
978+
async def unsafe_sleep_fn() -> int:
979+
"""Stub."""
980+
return await sleep_fn()
981+
982+
safe_sleep = Tool.from_function(sleep_fn)
983+
unsafe_sleep = Tool.from_function(unsafe_sleep_fn, concurrency_safe=False)
984+
985+
dummy_env = DummyEnv()
986+
await dummy_env.reset()
987+
dummy_env.tools = [safe_sleep, unsafe_sleep]
988+
989+
safes = [True, True, True, False, True, False, False, True, True]
990+
obs, *_ = await dummy_env.step(
991+
ToolRequestMessage(
992+
tool_calls=[
993+
ToolCall.from_tool(safe_sleep if safe else unsafe_sleep)
994+
for safe in safes
995+
]
996+
)
997+
)
998+
999+
at_least_one_parallel = False
1000+
for safe, msg in zip(safes, obs, strict=True):
1001+
count = int(cast(str, msg.content))
1002+
if safe:
1003+
at_least_one_parallel |= count > 1
1004+
else:
1005+
assert count == 1, "Expected unsafe tools to block all other tool calls."
1006+
1007+
assert at_least_one_parallel, (
1008+
"Expected at least one safe tool call to run concurrently with another."
1009+
)

0 commit comments

Comments
 (0)