Skip to content

Commit 82233e3

Browse files
fix: task execution cancelled by client disconnect
1 parent 598d8a1 commit 82233e3

File tree

3 files changed

+231
-42
lines changed

3 files changed

+231
-42
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 6 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
MessageSendParams,
3737
Task,
3838
TaskIdParams,
39-
TaskNotCancelableError,
4039
TaskNotFoundError,
4140
TaskPushNotificationConfig,
4241
TaskQueryParams,
@@ -112,26 +111,6 @@ async def on_get_task(
112111
task: Task | None = await self.task_store.get(params.id)
113112
if not task:
114113
raise ServerError(error=TaskNotFoundError())
115-
116-
# Apply historyLength parameter if specified
117-
if params.history_length is not None and task.history:
118-
# Limit history to the most recent N messages
119-
limited_history = (
120-
task.history[-params.history_length :]
121-
if params.history_length > 0
122-
else []
123-
)
124-
# Create a new task instance with limited history
125-
task = Task(
126-
id=task.id,
127-
context_id=task.context_id,
128-
status=task.status,
129-
artifacts=task.artifacts,
130-
history=limited_history,
131-
metadata=task.metadata,
132-
kind=task.kind,
133-
)
134-
135114
return task
136115

137116
async def on_cancel_task(
@@ -145,14 +124,6 @@ async def on_cancel_task(
145124
if not task:
146125
raise ServerError(error=TaskNotFoundError())
147126

148-
# Check if task is in a non-cancelable state (completed, canceled, failed, rejected)
149-
if task.status.state in TERMINAL_TASK_STATES:
150-
raise ServerError(
151-
error=TaskNotCancelableError(
152-
message=f'Task cannot be canceled - current state: {task.status.state}'
153-
)
154-
)
155-
156127
task_manager = TaskManager(
157128
task_id=task.id,
158129
context_id=task.context_id,
@@ -273,9 +244,7 @@ def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
273244
"""Validates that agent-generated task ID matches the expected task ID."""
274245
if task_id != event_task_id:
275246
logger.error(
276-
'Agent generated task_id=%s does not match the RequestContext task_id=%s.',
277-
event_task_id,
278-
task_id,
247+
f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.'
279248
)
280249
raise ServerError(
281250
InternalError(message='Task ID mismatch in agent response')
@@ -317,19 +286,11 @@ async def on_message_send(
317286

318287
interrupted_or_non_blocking = False
319288
try:
320-
# Create async callback for push notifications
321-
async def push_notification_callback() -> None:
322-
await self._send_push_notification_if_needed(
323-
task_id, result_aggregator
324-
)
325-
326289
(
327290
result,
328291
interrupted_or_non_blocking,
329292
) = await result_aggregator.consume_and_break_on_interrupt(
330-
consumer,
331-
blocking=blocking,
332-
event_callback=push_notification_callback,
293+
consumer, blocking=blocking
333294
)
334295
if not result:
335296
raise ServerError(error=InternalError()) # noqa: TRY301
@@ -385,7 +346,10 @@ async def on_message_send_stream(
385346
)
386347
yield event
387348
finally:
388-
await self._cleanup_producer(producer_task, task_id)
349+
# TODO: Track the cleanup task scheduled on client disconnect (keep a reference so it cannot be garbage-collected before it finishes).
350+
asyncio.create_task( # noqa: RUF006
351+
self._cleanup_producer(producer_task, task_id)
352+
)
389353

390354
async def _register_producer(
391355
self, task_id: str, producer_task: asyncio.Task

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,14 @@ async def test_on_message_send_stream_with_push_notification():
904904
configuration=message_config,
905905
)
906906

907+
# Latch to ensure background execute is scheduled before asserting
908+
execute_called = asyncio.Event()
909+
910+
async def exec_side_effect(*args, **kwargs):
911+
execute_called.set()
912+
913+
mock_agent_executor.execute.side_effect = exec_side_effect
914+
907915
# Mock ResultAggregator and its consume_and_emit
908916
mock_result_aggregator_instance = MagicMock(
909917
spec=ResultAggregator
@@ -1117,6 +1125,8 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11171125
):
11181126
pass
11191127

1128+
await asyncio.wait_for(execute_called.wait(), timeout=0.1)
1129+
11201130
# Assertions
11211131
# 1. set_info called once at the beginning if task exists (or after task is created from message)
11221132
mock_push_config_store.set_info.assert_any_call(task_id, push_config)
@@ -1129,6 +1139,202 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11291139
mock_agent_executor.execute.assert_awaited_once()
11301140

11311141

1142+
@pytest.mark.asyncio
1143+
async def test_stream_disconnect_then_resubscribe_receives_future_events():
1144+
"""Start streaming, disconnect, then resubscribe and ensure subsequent events are streamed."""
1145+
# Arrange
1146+
mock_task_store = AsyncMock(spec=TaskStore)
1147+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
1148+
1149+
# Use a real queue manager so taps receive future events
1150+
queue_manager = InMemoryQueueManager()
1151+
1152+
task_id = 'reconn_task_1'
1153+
context_id = 'reconn_ctx_1'
1154+
1155+
# Task exists and is non-final
1156+
task_for_resub = create_sample_task(
1157+
task_id=task_id, context_id=context_id, status_state=TaskState.working
1158+
)
1159+
mock_task_store.get.return_value = task_for_resub
1160+
1161+
request_handler = DefaultRequestHandler(
1162+
agent_executor=mock_agent_executor,
1163+
task_store=mock_task_store,
1164+
queue_manager=queue_manager,
1165+
)
1166+
1167+
params = MessageSendParams(
1168+
message=Message(
1169+
role=Role.user,
1170+
message_id='msg_reconn',
1171+
parts=[],
1172+
task_id=task_id,
1173+
context_id=context_id,
1174+
)
1175+
)
1176+
1177+
# Producer behavior: emit one event, then later emit second event
1178+
exec_started = asyncio.Event()
1179+
allow_second_event = asyncio.Event()
1180+
allow_finish = asyncio.Event()
1181+
1182+
first_event = create_sample_task(
1183+
task_id=task_id, context_id=context_id, status_state=TaskState.working
1184+
)
1185+
second_event = create_sample_task(
1186+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
1187+
)
1188+
1189+
async def exec_side_effect(_request, queue: EventQueue):
1190+
exec_started.set()
1191+
await queue.enqueue_event(first_event)
1192+
await allow_second_event.wait()
1193+
await queue.enqueue_event(second_event)
1194+
await allow_finish.wait()
1195+
1196+
mock_agent_executor.execute.side_effect = exec_side_effect
1197+
1198+
# Start streaming and consume first event
1199+
agen = request_handler.on_message_send_stream(
1200+
params, create_server_call_context()
1201+
)
1202+
first = await agen.__anext__()
1203+
assert first == first_event
1204+
1205+
# Simulate client disconnect
1206+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1207+
1208+
# Resubscribe and start consuming future events
1209+
resub_gen = request_handler.on_resubscribe_to_task(
1210+
TaskIdParams(id=task_id), create_server_call_context()
1211+
)
1212+
1213+
# Allow producer to emit the next event
1214+
allow_second_event.set()
1215+
1216+
received = await resub_gen.__anext__()
1217+
assert received == second_event
1218+
1219+
# Finish producer to allow cleanup paths to complete
1220+
allow_finish.set()
1221+
1222+
1223+
@pytest.mark.asyncio
1224+
async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues():
1225+
"""Simulate client disconnect: stream stops early, cleanup is scheduled in background,
1226+
producer keeps running, and cleanup completes after producer finishes."""
1227+
# Arrange
1228+
mock_task_store = AsyncMock(spec=TaskStore)
1229+
mock_queue_manager = AsyncMock(spec=QueueManager)
1230+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
1231+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
1232+
1233+
task_id = 'disc_task_1'
1234+
context_id = 'disc_ctx_1'
1235+
1236+
# RequestContext with IDs
1237+
mock_request_context = MagicMock(spec=RequestContext)
1238+
mock_request_context.task_id = task_id
1239+
mock_request_context.context_id = context_id
1240+
mock_request_context_builder.build.return_value = mock_request_context
1241+
1242+
# Queue used by _run_event_stream; must support close()
1243+
mock_queue = AsyncMock(spec=EventQueue)
1244+
mock_queue_manager.create_or_tap.return_value = mock_queue
1245+
1246+
request_handler = DefaultRequestHandler(
1247+
agent_executor=mock_agent_executor,
1248+
task_store=mock_task_store,
1249+
queue_manager=mock_queue_manager,
1250+
request_context_builder=mock_request_context_builder,
1251+
)
1252+
1253+
params = MessageSendParams(
1254+
message=Message(
1255+
role=Role.user,
1256+
message_id='mid',
1257+
parts=[],
1258+
task_id=task_id,
1259+
context_id=context_id,
1260+
)
1261+
)
1262+
1263+
# Agent executor runs in background until we allow it to finish
1264+
execute_started = asyncio.Event()
1265+
execute_finish = asyncio.Event()
1266+
1267+
async def exec_side_effect(*_args, **_kwargs):
1268+
execute_started.set()
1269+
await execute_finish.wait()
1270+
1271+
mock_agent_executor.execute.side_effect = exec_side_effect
1272+
1273+
# ResultAggregator emits one Task event (so the stream yields once)
1274+
first_event = create_sample_task(task_id=task_id, context_id=context_id)
1275+
1276+
async def single_event_stream():
1277+
yield first_event
1278+
# will never yield again; client will disconnect
1279+
1280+
mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
1281+
mock_result_aggregator_instance.consume_and_emit.return_value = (
1282+
single_event_stream()
1283+
)
1284+
1285+
produced_task: asyncio.Task | None = None
1286+
cleanup_task: asyncio.Task | None = None
1287+
1288+
orig_create_task = asyncio.create_task
1289+
1290+
def create_task_spy(coro):
1291+
nonlocal produced_task, cleanup_task
1292+
task = orig_create_task(coro)
1293+
if produced_task is None:
1294+
produced_task = task
1295+
else:
1296+
cleanup_task = task
1297+
return task
1298+
1299+
with (
1300+
patch(
1301+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
1302+
return_value=mock_result_aggregator_instance,
1303+
),
1304+
patch('asyncio.create_task', side_effect=create_task_spy),
1305+
):
1306+
# Act: start stream and consume only the first event, then disconnect
1307+
agen = request_handler.on_message_send_stream(
1308+
params, create_server_call_context()
1309+
)
1310+
first = await agen.__anext__()
1311+
assert first == first_event
1312+
# Simulate client disconnect
1313+
await asyncio.wait_for(agen.aclose(), timeout=0.1)
1314+
1315+
# Assert cleanup was scheduled and producer was started
1316+
assert produced_task is not None
1317+
assert cleanup_task is not None
1318+
1319+
# execute should have started
1320+
await asyncio.wait_for(execute_started.wait(), timeout=0.1)
1321+
1322+
# Producer should still be running (not finished immediately on disconnect)
1323+
assert not produced_task.done()
1324+
1325+
# Allow executor to finish, which should complete producer and then cleanup
1326+
execute_finish.set()
1327+
await asyncio.wait_for(produced_task, timeout=0.2)
1328+
await asyncio.wait_for(cleanup_task, timeout=0.2)
1329+
1330+
# Queue close awaited by _run_event_stream
1331+
mock_queue.close.assert_awaited_once()
1332+
# QueueManager close called by _cleanup_producer
1333+
mock_queue_manager.close.assert_awaited_once_with(task_id)
1334+
# The task's entry in _running_agents has been removed
1335+
assert task_id not in request_handler._running_agents
1336+
1337+
11321338
@pytest.mark.asyncio
11331339
async def test_on_message_send_stream_task_id_mismatch():
11341340
"""Test on_message_send_stream raises error if yielded task ID mismatches."""

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import unittest
23
import unittest.async_case
34

@@ -361,6 +362,14 @@ async def streaming_coro():
361362
for event in events:
362363
yield event
363364

365+
# Latch to ensure background execute is scheduled before asserting
366+
execute_called = asyncio.Event()
367+
368+
async def exec_side_effect(*args, **kwargs):
369+
execute_called.set()
370+
371+
mock_agent_executor.execute.side_effect = exec_side_effect
372+
364373
with patch(
365374
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
366375
return_value=streaming_coro(),
@@ -382,6 +391,7 @@ async def streaming_coro():
382391
event.root, SendStreamingMessageSuccessResponse
383392
)
384393
assert event.root.result == events[i]
394+
await asyncio.wait_for(execute_called.wait(), timeout=0.1)
385395
mock_agent_executor.execute.assert_called_once()
386396

387397
async def test_on_message_stream_new_message_existing_task_success(
@@ -418,6 +428,14 @@ async def streaming_coro():
418428
for event in events:
419429
yield event
420430

431+
# Latch to ensure background execute is scheduled before asserting
432+
execute_called = asyncio.Event()
433+
434+
async def exec_side_effect(*args, **kwargs):
435+
execute_called.set()
436+
437+
mock_agent_executor.execute.side_effect = exec_side_effect
438+
421439
with patch(
422440
'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all',
423441
return_value=streaming_coro(),
@@ -438,6 +456,7 @@ async def streaming_coro():
438456
assert isinstance(response, AsyncGenerator)
439457
collected_events = [item async for item in response]
440458
assert len(collected_events) == len(events)
459+
await asyncio.wait_for(execute_called.wait(), timeout=0.1)
441460
mock_agent_executor.execute.assert_called_once()
442461
assert mock_task.history is not None and len(mock_task.history) == 1
443462

0 commit comments

Comments
 (0)