@@ -954,6 +954,14 @@ async def test_on_message_send_stream_with_push_notification():
954954 configuration = message_config ,
955955 )
956956
957+ # Latch to ensure background execute is scheduled before asserting
958+ execute_called = asyncio .Event ()
959+
960+ async def exec_side_effect (* args , ** kwargs ):
961+ execute_called .set ()
962+
963+ mock_agent_executor .execute .side_effect = exec_side_effect
964+
957965 # Mock ResultAggregator and its consume_and_emit
958966 mock_result_aggregator_instance = MagicMock (
959967 spec = ResultAggregator
@@ -1167,6 +1175,8 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11671175 ):
11681176 pass
11691177
1178+ await asyncio .wait_for (execute_called .wait (), timeout = 0.1 )
1179+
11701180 # Assertions
11711181 # 1. set_info called once at the beginning if task exists (or after task is created from message)
11721182 mock_push_config_store .set_info .assert_any_call (task_id , push_config )
@@ -1179,6 +1189,202 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11791189 mock_agent_executor .execute .assert_awaited_once ()
11801190
11811191
1192+ @pytest .mark .asyncio
1193+ async def test_stream_disconnect_then_resubscribe_receives_future_events ():
1194+ """Start streaming, disconnect, then resubscribe and ensure subsequent events are streamed."""
1195+ # Arrange
1196+ mock_task_store = AsyncMock (spec = TaskStore )
1197+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
1198+
1199+ # Use a real queue manager so taps receive future events
1200+ queue_manager = InMemoryQueueManager ()
1201+
1202+ task_id = 'reconn_task_1'
1203+ context_id = 'reconn_ctx_1'
1204+
1205+ # Task exists and is non-final
1206+ task_for_resub = create_sample_task (
1207+ task_id = task_id , context_id = context_id , status_state = TaskState .working
1208+ )
1209+ mock_task_store .get .return_value = task_for_resub
1210+
1211+ request_handler = DefaultRequestHandler (
1212+ agent_executor = mock_agent_executor ,
1213+ task_store = mock_task_store ,
1214+ queue_manager = queue_manager ,
1215+ )
1216+
1217+ params = MessageSendParams (
1218+ message = Message (
1219+ role = Role .user ,
1220+ message_id = 'msg_reconn' ,
1221+ parts = [],
1222+ task_id = task_id ,
1223+ context_id = context_id ,
1224+ )
1225+ )
1226+
1227+ # Producer behavior: emit one event, then later emit second event
1228+ exec_started = asyncio .Event ()
1229+ allow_second_event = asyncio .Event ()
1230+ allow_finish = asyncio .Event ()
1231+
1232+ first_event = create_sample_task (
1233+ task_id = task_id , context_id = context_id , status_state = TaskState .working
1234+ )
1235+ second_event = create_sample_task (
1236+ task_id = task_id , context_id = context_id , status_state = TaskState .completed
1237+ )
1238+
1239+ async def exec_side_effect (_request , queue : EventQueue ):
1240+ exec_started .set ()
1241+ await queue .enqueue_event (first_event )
1242+ await allow_second_event .wait ()
1243+ await queue .enqueue_event (second_event )
1244+ await allow_finish .wait ()
1245+
1246+ mock_agent_executor .execute .side_effect = exec_side_effect
1247+
1248+ # Start streaming and consume first event
1249+ agen = request_handler .on_message_send_stream (
1250+ params , create_server_call_context ()
1251+ )
1252+ first = await agen .__anext__ ()
1253+ assert first == first_event
1254+
1255+ # Simulate client disconnect
1256+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1257+
1258+ # Resubscribe and start consuming future events
1259+ resub_gen = request_handler .on_resubscribe_to_task (
1260+ TaskIdParams (id = task_id ), create_server_call_context ()
1261+ )
1262+
1263+ # Allow producer to emit the next event
1264+ allow_second_event .set ()
1265+
1266+ received = await resub_gen .__anext__ ()
1267+ assert received == second_event
1268+
1269+ # Finish producer to allow cleanup paths to complete
1270+ allow_finish .set ()
1271+
1272+
1273+ @pytest .mark .asyncio
1274+ async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues ():
1275+ """Simulate client disconnect: stream stops early, cleanup is scheduled in background,
1276+ producer keeps running, and cleanup completes after producer finishes."""
1277+ # Arrange
1278+ mock_task_store = AsyncMock (spec = TaskStore )
1279+ mock_queue_manager = AsyncMock (spec = QueueManager )
1280+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
1281+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
1282+
1283+ task_id = 'disc_task_1'
1284+ context_id = 'disc_ctx_1'
1285+
1286+ # RequestContext with IDs
1287+ mock_request_context = MagicMock (spec = RequestContext )
1288+ mock_request_context .task_id = task_id
1289+ mock_request_context .context_id = context_id
1290+ mock_request_context_builder .build .return_value = mock_request_context
1291+
1292+ # Queue used by _run_event_stream; must support close()
1293+ mock_queue = AsyncMock (spec = EventQueue )
1294+ mock_queue_manager .create_or_tap .return_value = mock_queue
1295+
1296+ request_handler = DefaultRequestHandler (
1297+ agent_executor = mock_agent_executor ,
1298+ task_store = mock_task_store ,
1299+ queue_manager = mock_queue_manager ,
1300+ request_context_builder = mock_request_context_builder ,
1301+ )
1302+
1303+ params = MessageSendParams (
1304+ message = Message (
1305+ role = Role .user ,
1306+ message_id = 'mid' ,
1307+ parts = [],
1308+ task_id = task_id ,
1309+ context_id = context_id ,
1310+ )
1311+ )
1312+
1313+ # Agent executor runs in background until we allow it to finish
1314+ execute_started = asyncio .Event ()
1315+ execute_finish = asyncio .Event ()
1316+
1317+ async def exec_side_effect (* _args , ** _kwargs ):
1318+ execute_started .set ()
1319+ await execute_finish .wait ()
1320+
1321+ mock_agent_executor .execute .side_effect = exec_side_effect
1322+
1323+ # ResultAggregator emits one Task event (so the stream yields once)
1324+ first_event = create_sample_task (task_id = task_id , context_id = context_id )
1325+
1326+ async def single_event_stream ():
1327+ yield first_event
1328+ # will never yield again; client will disconnect
1329+
1330+ mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
1331+ mock_result_aggregator_instance .consume_and_emit .return_value = (
1332+ single_event_stream ()
1333+ )
1334+
1335+ produced_task : asyncio .Task | None = None
1336+ cleanup_task : asyncio .Task | None = None
1337+
1338+ orig_create_task = asyncio .create_task
1339+
1340+ def create_task_spy (coro ):
1341+ nonlocal produced_task , cleanup_task
1342+ task = orig_create_task (coro )
1343+ if produced_task is None :
1344+ produced_task = task
1345+ else :
1346+ cleanup_task = task
1347+ return task
1348+
1349+ with (
1350+ patch (
1351+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
1352+ return_value = mock_result_aggregator_instance ,
1353+ ),
1354+ patch ('asyncio.create_task' , side_effect = create_task_spy ),
1355+ ):
1356+ # Act: start stream and consume only the first event, then disconnect
1357+ agen = request_handler .on_message_send_stream (
1358+ params , create_server_call_context ()
1359+ )
1360+ first = await agen .__anext__ ()
1361+ assert first == first_event
1362+ # Simulate client disconnect
1363+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1364+
1365+ # Assert cleanup was scheduled and producer was started
1366+ assert produced_task is not None
1367+ assert cleanup_task is not None
1368+
1369+ # execute should have started
1370+ await asyncio .wait_for (execute_started .wait (), timeout = 0.1 )
1371+
1372+ # Producer should still be running (not finished immediately on disconnect)
1373+ assert not produced_task .done ()
1374+
1375+ # Allow executor to finish, which should complete producer and then cleanup
1376+ execute_finish .set ()
1377+ await asyncio .wait_for (produced_task , timeout = 0.2 )
1378+ await asyncio .wait_for (cleanup_task , timeout = 0.2 )
1379+
1380+ # Queue close awaited by _run_event_stream
1381+ mock_queue .close .assert_awaited_once ()
1382+ # QueueManager close called by _cleanup_producer
1383+ mock_queue_manager .close .assert_awaited_once_with (task_id )
1384+ # Running agents is cleared
1385+ assert task_id not in request_handler ._running_agents
1386+
1387+
11821388@pytest .mark .asyncio
11831389async def test_on_message_send_stream_task_id_mismatch ():
11841390 """Test on_message_send_stream raises error if yielded task ID mismatches."""
0 commit comments