Skip to content

Commit 5271d01

Browse files
fix: task execution cancelled by client disconnect
1 parent 6d0ef59 commit 5271d01

File tree

3 files changed

+229
-1
lines changed

3 files changed

+229
-1
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ async def on_message_send_stream(
394394
)
395395
yield event
396396
finally:
397-
await self._cleanup_producer(producer_task, task_id)
397+
# TODO: Track this disconnected cleanup task.
398+
asyncio.create_task( # noqa: RUF006
399+
self._cleanup_producer(producer_task, task_id)
400+
)
398401

399402
async def _register_producer(
400403
self, task_id: str, producer_task: asyncio.Task

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,14 @@ async def test_on_message_send_stream_with_push_notification():
954954
configuration=message_config,
955955
)
956956

957+
# Latch to ensure background execute is scheduled before asserting
958+
execute_called = asyncio.Event()
959+
960+
async def exec_side_effect(*args, **kwargs):
961+
execute_called.set()
962+
963+
mock_agent_executor.execute.side_effect = exec_side_effect
964+
957965
# Mock ResultAggregator and its consume_and_emit
958966
mock_result_aggregator_instance = MagicMock(
959967
spec=ResultAggregator
@@ -1167,6 +1175,8 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11671175
):
11681176
pass
11691177

1178+
await asyncio.wait_for(execute_called.wait(), timeout=0.1)
1179+
11701180
# Assertions
11711181
# 1. set_info called once at the beginning if task exists (or after task is created from message)
11721182
mock_push_config_store.set_info.assert_any_call(task_id, push_config)
@@ -1179,6 +1189,202 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11791189
mock_agent_executor.execute.assert_awaited_once()
11801190

11811191

1192+
@pytest.mark.asyncio
async def test_stream_disconnect_then_resubscribe_receives_future_events():
    """A resubscribed client still receives events emitted after the first client left.

    Flow: open a stream, read one event, close the stream (client
    disconnect), resubscribe to the same task, then let the executor emit a
    second event and check that the resubscribed stream yields it.
    """
    # Identifiers shared by the stored task, the request and both events.
    task_id = 'reconn_task_1'
    context_id = 'reconn_ctx_1'
    ids = {'task_id': task_id, 'context_id': context_id}

    # Events the fake executor publishes: a working update first, then the
    # completed task once the test releases it.
    first_event = create_sample_task(**ids, status_state=TaskState.working)
    second_event = create_sample_task(**ids, status_state=TaskState.completed)

    # Latches that let the test drive the executor step by step.
    exec_started = asyncio.Event()
    allow_second_event = asyncio.Event()
    allow_finish = asyncio.Event()

    async def exec_side_effect(_request, queue: EventQueue):
        exec_started.set()
        await queue.enqueue_event(first_event)
        await allow_second_event.wait()
        await queue.enqueue_event(second_event)
        await allow_finish.wait()

    mock_agent_executor = AsyncMock(spec=AgentExecutor)
    mock_agent_executor.execute.side_effect = exec_side_effect

    # The task store reports an existing, non-final task.
    mock_task_store = AsyncMock(spec=TaskStore)
    mock_task_store.get.return_value = create_sample_task(
        **ids, status_state=TaskState.working
    )

    # A real queue manager is used so that taps receive future events.
    request_handler = DefaultRequestHandler(
        agent_executor=mock_agent_executor,
        task_store=mock_task_store,
        queue_manager=InMemoryQueueManager(),
    )

    user_message = Message(
        role=Role.user,
        message_id='msg_reconn',
        parts=[],
        task_id=task_id,
        context_id=context_id,
    )
    params = MessageSendParams(message=user_message)

    # Open the stream and read exactly one event.
    agen = request_handler.on_message_send_stream(
        params, create_server_call_context()
    )
    first = await agen.__anext__()
    assert first == first_event

    # Client disconnect: closing the generator must return promptly.
    await asyncio.wait_for(agen.aclose(), timeout=0.1)

    # Resubscribe before the next event is released.
    resub_gen = request_handler.on_resubscribe_to_task(
        TaskIdParams(id=task_id), create_server_call_context()
    )

    # Release the second event and expect it on the resubscribed stream.
    allow_second_event.set()
    received = await resub_gen.__anext__()
    assert received == second_event

    # Unblock the executor so the cleanup paths can complete.
    allow_finish.set()
1271+
1272+
1273+
@pytest.mark.asyncio
async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues():
    """Simulate a client disconnect in the middle of a stream.

    The stream stops early, cleanup is scheduled in the background, the
    producer keeps running, and cleanup completes after the producer
    finishes.
    """
    # Arrange
    mock_task_store = AsyncMock(spec=TaskStore)
    mock_queue_manager = AsyncMock(spec=QueueManager)
    mock_agent_executor = AsyncMock(spec=AgentExecutor)
    mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)

    task_id = 'disc_task_1'
    context_id = 'disc_ctx_1'

    # RequestContext with IDs
    mock_request_context = MagicMock(spec=RequestContext)
    mock_request_context.task_id = task_id
    mock_request_context.context_id = context_id
    mock_request_context_builder.build.return_value = mock_request_context

    # Queue used by _run_event_stream; must support close()
    mock_queue = AsyncMock(spec=EventQueue)
    mock_queue_manager.create_or_tap.return_value = mock_queue

    request_handler = DefaultRequestHandler(
        agent_executor=mock_agent_executor,
        task_store=mock_task_store,
        queue_manager=mock_queue_manager,
        request_context_builder=mock_request_context_builder,
    )

    params = MessageSendParams(
        message=Message(
            role=Role.user,
            message_id='mid',
            parts=[],
            task_id=task_id,
            context_id=context_id,
        )
    )

    # Agent executor runs in background until we allow it to finish
    execute_started = asyncio.Event()
    execute_finish = asyncio.Event()

    async def exec_side_effect(*_args, **_kwargs):
        execute_started.set()
        await execute_finish.wait()

    mock_agent_executor.execute.side_effect = exec_side_effect

    # ResultAggregator emits one Task event (so the stream yields once)
    first_event = create_sample_task(task_id=task_id, context_id=context_id)

    async def single_event_stream():
        yield first_event
        # will never yield again; client will disconnect

    mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
    mock_result_aggregator_instance.consume_and_emit.return_value = (
        single_event_stream()
    )

    produced_task: asyncio.Task | None = None
    cleanup_task: asyncio.Task | None = None

    # Capture the real function before patching so the spy can delegate to it.
    orig_create_task = asyncio.create_task

    def create_task_spy(coro, **kwargs):
        """Create the task for real and remember it.

        The first task created is recorded as the producer; any later task
        is recorded as the cleanup task. Keyword arguments (``name``,
        ``context``) are forwarded so the spy stays call-compatible with
        ``asyncio.create_task``.
        """
        nonlocal produced_task, cleanup_task
        task = orig_create_task(coro, **kwargs)
        if produced_task is None:
            produced_task = task
        else:
            cleanup_task = task
        return task

    with (
        patch(
            'a2a.server.request_handlers.default_request_handler.ResultAggregator',
            return_value=mock_result_aggregator_instance,
        ),
        patch('asyncio.create_task', side_effect=create_task_spy),
    ):
        # Act: start stream and consume only the first event, then disconnect
        agen = request_handler.on_message_send_stream(
            params, create_server_call_context()
        )
        first = await agen.__anext__()
        assert first == first_event
        # Simulate client disconnect; closing must not block on the producer
        await asyncio.wait_for(agen.aclose(), timeout=0.1)

    # Assert cleanup was scheduled and producer was started
    assert produced_task is not None
    assert cleanup_task is not None

    # execute should have started
    await asyncio.wait_for(execute_started.wait(), timeout=0.1)

    # Producer should still be running (not finished immediately on disconnect)
    assert not produced_task.done()

    # Allow executor to finish, which should complete producer and then cleanup
    execute_finish.set()
    await asyncio.wait_for(produced_task, timeout=0.2)
    await asyncio.wait_for(cleanup_task, timeout=0.2)

    # Queue close awaited by _run_event_stream
    mock_queue.close.assert_awaited_once()
    # QueueManager close called by _cleanup_producer
    mock_queue_manager.close.assert_awaited_once_with(task_id)
    # Running agents is cleared
    assert task_id not in request_handler._running_agents
1386+
1387+
11821388
@pytest.mark.asyncio
11831389
async def test_on_message_send_stream_task_id_mismatch():
11841390
"""Test on_message_send_stream raises error if yielded task ID mismatches."""

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import unittest
23
import unittest.async_case
34

@@ -366,6 +367,14 @@ async def streaming_coro():
366367
for event in events:
367368
yield event
368369

370+
# Latch to ensure background execute is scheduled before asserting
371+
execute_called = asyncio.Event()
372+
373+
async def exec_side_effect(*args, **kwargs):
374+
execute_called.set()
375+
376+
mock_agent_executor.execute.side_effect = exec_side_effect
377+
369378
with patch(
370379
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
371380
return_value=streaming_coro(),
@@ -387,6 +396,7 @@ async def streaming_coro():
387396
event.root, SendStreamingMessageSuccessResponse
388397
)
389398
assert event.root.result == events[i]
399+
await asyncio.wait_for(execute_called.wait(), timeout=0.1)
390400
mock_agent_executor.execute.assert_called_once()
391401

392402
async def test_on_message_stream_new_message_existing_task_success(
@@ -423,6 +433,14 @@ async def streaming_coro():
423433
for event in events:
424434
yield event
425435

436+
# Latch to ensure background execute is scheduled before asserting
437+
execute_called = asyncio.Event()
438+
439+
async def exec_side_effect(*args, **kwargs):
440+
execute_called.set()
441+
442+
mock_agent_executor.execute.side_effect = exec_side_effect
443+
426444
with patch(
427445
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
428446
return_value=streaming_coro(),
@@ -443,6 +461,7 @@ async def streaming_coro():
443461
assert isinstance(response, AsyncGenerator)
444462
collected_events = [item async for item in response]
445463
assert len(collected_events) == len(events)
464+
await asyncio.wait_for(execute_called.wait(), timeout=0.1)
446465
mock_agent_executor.execute.assert_called_once()
447466
assert mock_task.history is not None and len(mock_task.history) == 1
448467

0 commit comments

Comments
 (0)