Skip to content

Commit 0b151f6

Browse files
Otto-AGPT and majdyz authored
feat(copilot): Execute parallel tool calls concurrently (#12165)
When the LLM returns multiple tool calls in a single response (e.g. multiple web fetches for a research task), they now execute concurrently instead of sequentially. This can dramatically reduce latency for multi-tool turns. **Before:** Tool calls execute one after another — 7 web fetches × 2s each = 14s total **After:** All tool calls fire concurrently — 7 web fetches = ~2s total ### Changes - **`service.py`**: New `_execute_tool_calls_parallel()` function that spawns tool calls as concurrent `asyncio` tasks, collecting stream events via `asyncio.Queue` - **`service.py`**: `_yield_tool_call()` now accepts an optional `session_lock` parameter for concurrent-safe session mutations - **`base.py`**: Session lock exposed via `contextvars` so tools that need it can access it without interface changes - **`run_agent.py`**: Rate-limit counters (`successful_agent_runs`, `successful_agent_schedules`) protected with the session lock to prevent race conditions ### Concurrency Safety | Shared State | Risk | Mitigation | |---|---|---| | `session.messages` (long-running tools only) | Race on append + upsert | `session_lock` wraps mutations | | `session.successful_agent_runs` counter | Bypass max-runs check | `session_lock` wraps read-check-increment | | Tool-internal state (DB queries, API calls) | None — stateless | No mitigation needed | ### Testing - Added `parallel_tool_calls_test.py` with tests for: - Parallel timing verification (sum vs max of delays) - Single tool call regression - Retryable error propagation - Shared session lock verification - Cancellation cleanup Closes SECRT-2016 --------- Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
1 parent be2a48a commit 0b151f6

File tree

2 files changed

+382
-47
lines changed

2 files changed

+382
-47
lines changed
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
"""Tests for parallel tool call execution in CoPilot.
2+
3+
These tests mock _yield_tool_call to avoid importing the full copilot stack
4+
which requires Prisma, DB connections, etc.
5+
"""
6+
7+
import asyncio
8+
import time
9+
from typing import Any, cast
10+
11+
import pytest
12+
13+
14+
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
    """Multiple tool calls should complete in ~max(delays), not sum(delays)."""
    # Import here to allow module-level mocking if needed
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    n_tools = 3
    delay_per_tool = 0.2
    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"tool_{i}", "arguments": "{}"},
        }
        for i in range(n_tools)
    ]

    # Minimal session mock; `messages` exists in case the service appends to it.
    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        # Emit input event, simulate per-tool latency, then emit output event.
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"],
            toolName=tc_list[idx]["function"]["name"],
            input={},
        )
        await asyncio.sleep(delay_per_tool)
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"],
            toolName=tc_list[idx]["function"]["name"],
            output="{}",
        )

    import backend.copilot.service as svc

    # Patch the module attribute directly; restored in `finally` so a failing
    # assertion cannot leak the mock into other tests.
    original_yield = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        start = time.monotonic()
        events = []
        async for event in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            events.append(event)
        elapsed = time.monotonic() - start
    finally:
        svc._yield_tool_call = original_yield

    # One input + one output event per tool.
    assert len(events) == n_tools * 2
    # Parallel: should take ~delay, not ~n*delay.
    assert elapsed < delay_per_tool * (
        n_tools - 0.5
    ), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
78+
79+
80+
@pytest.mark.asyncio
async def test_single_tool_call_works():
    """Single tool call should work identically."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "t", "arguments": "{}"},
        }
    ]

    # Minimal session stand-in for the service's session parameter.
    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        # Two events per call: tool input, then tool output.
        yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
        yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")

    import backend.copilot.service as svc

    saved_yield = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        collected = []
        async for event in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            collected.append(event)
    finally:
        # Always restore the real implementation.
        svc._yield_tool_call = saved_yield

    assert len(collected) == 2
123+
124+
125+
@pytest.mark.asyncio
async def test_retryable_error_propagates():
    """Retryable errors should be raised after all tools finish."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(2)
    ]

    # Minimal session stand-in.
    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        # Second tool fails immediately; first completes normally.
        if idx == 1:
            raise KeyError("bad")
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
        )
        await asyncio.sleep(0.05)
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
        )

    import backend.copilot.service as svc

    saved_yield = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        seen = []
        with pytest.raises(KeyError):
            async for event in _execute_tool_calls_parallel(
                tool_calls, cast(Any, FakeSession())
            ):
                seen.append(event)
        # First tool's events should still be yielded before the error surfaces.
        assert any(isinstance(e, StreamToolOutputAvailable) for e in seen)
    finally:
        svc._yield_tool_call = saved_yield
175+
176+
177+
@pytest.mark.asyncio
async def test_session_lock_shared():
    """All parallel tools should receive the same lock instance."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(3)
    ]

    # Minimal session stand-in.
    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    observed_locks = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        # Record the lock each parallel invocation receives.
        observed_locks.append(lock)
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
        )
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
        )

    import backend.copilot.service as svc

    saved_yield = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        # Drain the stream; we only care about the locks recorded above.
        async for _ in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            pass
    finally:
        svc._yield_tool_call = saved_yield

    assert len(observed_locks) == 3
    # Every tool call must share one lock instance for safe session mutation.
    assert observed_locks[0] is observed_locks[1] is observed_locks[2]
    assert isinstance(observed_locks[0], asyncio.Lock)
228+
229+
230+
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
    """Generator close should cancel in-flight tasks."""
    from backend.copilot.response_model import StreamToolInputAvailable
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(2)
    ]

    # Minimal session stand-in.
    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    started = asyncio.Event()

    async def fake_yield(tc_list, idx, sess, lock=None):
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
        )
        started.set()
        await asyncio.sleep(10)  # simulate long-running

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
        # Bound every await: if cleanup (or startup) is broken, the test should
        # fail with TimeoutError instead of hanging the whole suite for 10s+.
        await asyncio.wait_for(gen.__anext__(), timeout=5)  # get first event
        await asyncio.wait_for(started.wait(), timeout=5)
        # Closing the generator must cancel the in-flight sleeping tasks.
        await asyncio.wait_for(gen.aclose(), timeout=5)
    finally:
        svc._yield_tool_call = orig
    # Reaching here within the timeouts means cleanup worked.

0 commit comments

Comments
 (0)