@@ -904,6 +904,14 @@ async def test_on_message_send_stream_with_push_notification():
904904 configuration = message_config ,
905905 )
906906
907+ # Latch to ensure background execute is scheduled before asserting
908+ execute_called = asyncio .Event ()
909+
910+ async def exec_side_effect (* args , ** kwargs ):
911+ execute_called .set ()
912+
913+ mock_agent_executor .execute .side_effect = exec_side_effect
914+
907915 # Mock ResultAggregator and its consume_and_emit
908916 mock_result_aggregator_instance = MagicMock (
909917 spec = ResultAggregator
@@ -1117,6 +1125,8 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11171125 ):
11181126 pass
11191127
1128+ await asyncio .wait_for (execute_called .wait (), timeout = 0.1 )
1129+
11201130 # Assertions
11211131 # 1. set_info called once at the beginning if task exists (or after task is created from message)
11221132 mock_push_config_store .set_info .assert_any_call (task_id , push_config )
@@ -1129,6 +1139,202 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
11291139 mock_agent_executor .execute .assert_awaited_once ()
11301140
11311141
1142+ @pytest .mark .asyncio
1143+ async def test_stream_disconnect_then_resubscribe_receives_future_events ():
1144+ """Start streaming, disconnect, then resubscribe and ensure subsequent events are streamed."""
1145+ # Arrange
1146+ mock_task_store = AsyncMock (spec = TaskStore )
1147+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
1148+
1149+ # Use a real queue manager so taps receive future events
1150+ queue_manager = InMemoryQueueManager ()
1151+
1152+ task_id = 'reconn_task_1'
1153+ context_id = 'reconn_ctx_1'
1154+
1155+ # Task exists and is non-final
1156+ task_for_resub = create_sample_task (
1157+ task_id = task_id , context_id = context_id , status_state = TaskState .working
1158+ )
1159+ mock_task_store .get .return_value = task_for_resub
1160+
1161+ request_handler = DefaultRequestHandler (
1162+ agent_executor = mock_agent_executor ,
1163+ task_store = mock_task_store ,
1164+ queue_manager = queue_manager ,
1165+ )
1166+
1167+ params = MessageSendParams (
1168+ message = Message (
1169+ role = Role .user ,
1170+ message_id = 'msg_reconn' ,
1171+ parts = [],
1172+ task_id = task_id ,
1173+ context_id = context_id ,
1174+ )
1175+ )
1176+
1177+ # Producer behavior: emit one event, then later emit second event
1178+ exec_started = asyncio .Event ()
1179+ allow_second_event = asyncio .Event ()
1180+ allow_finish = asyncio .Event ()
1181+
1182+ first_event = create_sample_task (
1183+ task_id = task_id , context_id = context_id , status_state = TaskState .working
1184+ )
1185+ second_event = create_sample_task (
1186+ task_id = task_id , context_id = context_id , status_state = TaskState .completed
1187+ )
1188+
1189+ async def exec_side_effect (_request , queue : EventQueue ):
1190+ exec_started .set ()
1191+ await queue .enqueue_event (first_event )
1192+ await allow_second_event .wait ()
1193+ await queue .enqueue_event (second_event )
1194+ await allow_finish .wait ()
1195+
1196+ mock_agent_executor .execute .side_effect = exec_side_effect
1197+
1198+ # Start streaming and consume first event
1199+ agen = request_handler .on_message_send_stream (
1200+ params , create_server_call_context ()
1201+ )
1202+ first = await agen .__anext__ ()
1203+ assert first == first_event
1204+
1205+ # Simulate client disconnect
1206+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1207+
1208+ # Resubscribe and start consuming future events
1209+ resub_gen = request_handler .on_resubscribe_to_task (
1210+ TaskIdParams (id = task_id ), create_server_call_context ()
1211+ )
1212+
1213+ # Allow producer to emit the next event
1214+ allow_second_event .set ()
1215+
1216+ received = await resub_gen .__anext__ ()
1217+ assert received == second_event
1218+
1219+ # Finish producer to allow cleanup paths to complete
1220+ allow_finish .set ()
1221+
1222+
1223+ @pytest .mark .asyncio
1224+ async def test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues ():
1225+ """Simulate client disconnect: stream stops early, cleanup is scheduled in background,
1226+ producer keeps running, and cleanup completes after producer finishes."""
1227+ # Arrange
1228+ mock_task_store = AsyncMock (spec = TaskStore )
1229+ mock_queue_manager = AsyncMock (spec = QueueManager )
1230+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
1231+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
1232+
1233+ task_id = 'disc_task_1'
1234+ context_id = 'disc_ctx_1'
1235+
1236+ # RequestContext with IDs
1237+ mock_request_context = MagicMock (spec = RequestContext )
1238+ mock_request_context .task_id = task_id
1239+ mock_request_context .context_id = context_id
1240+ mock_request_context_builder .build .return_value = mock_request_context
1241+
1242+ # Queue used by _run_event_stream; must support close()
1243+ mock_queue = AsyncMock (spec = EventQueue )
1244+ mock_queue_manager .create_or_tap .return_value = mock_queue
1245+
1246+ request_handler = DefaultRequestHandler (
1247+ agent_executor = mock_agent_executor ,
1248+ task_store = mock_task_store ,
1249+ queue_manager = mock_queue_manager ,
1250+ request_context_builder = mock_request_context_builder ,
1251+ )
1252+
1253+ params = MessageSendParams (
1254+ message = Message (
1255+ role = Role .user ,
1256+ message_id = 'mid' ,
1257+ parts = [],
1258+ task_id = task_id ,
1259+ context_id = context_id ,
1260+ )
1261+ )
1262+
1263+ # Agent executor runs in background until we allow it to finish
1264+ execute_started = asyncio .Event ()
1265+ execute_finish = asyncio .Event ()
1266+
1267+ async def exec_side_effect (* _args , ** _kwargs ):
1268+ execute_started .set ()
1269+ await execute_finish .wait ()
1270+
1271+ mock_agent_executor .execute .side_effect = exec_side_effect
1272+
1273+ # ResultAggregator emits one Task event (so the stream yields once)
1274+ first_event = create_sample_task (task_id = task_id , context_id = context_id )
1275+
1276+ async def single_event_stream ():
1277+ yield first_event
1278+ # will never yield again; client will disconnect
1279+
1280+ mock_result_aggregator_instance = MagicMock (spec = ResultAggregator )
1281+ mock_result_aggregator_instance .consume_and_emit .return_value = (
1282+ single_event_stream ()
1283+ )
1284+
1285+ produced_task : asyncio .Task | None = None
1286+ cleanup_task : asyncio .Task | None = None
1287+
1288+ orig_create_task = asyncio .create_task
1289+
1290+ def create_task_spy (coro ):
1291+ nonlocal produced_task , cleanup_task
1292+ task = orig_create_task (coro )
1293+ if produced_task is None :
1294+ produced_task = task
1295+ else :
1296+ cleanup_task = task
1297+ return task
1298+
1299+ with (
1300+ patch (
1301+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
1302+ return_value = mock_result_aggregator_instance ,
1303+ ),
1304+ patch ('asyncio.create_task' , side_effect = create_task_spy ),
1305+ ):
1306+ # Act: start stream and consume only the first event, then disconnect
1307+ agen = request_handler .on_message_send_stream (
1308+ params , create_server_call_context ()
1309+ )
1310+ first = await agen .__anext__ ()
1311+ assert first == first_event
1312+ # Simulate client disconnect
1313+ await asyncio .wait_for (agen .aclose (), timeout = 0.1 )
1314+
1315+ # Assert cleanup was scheduled and producer was started
1316+ assert produced_task is not None
1317+ assert cleanup_task is not None
1318+
1319+ # execute should have started
1320+ await asyncio .wait_for (execute_started .wait (), timeout = 0.1 )
1321+
1322+ # Producer should still be running (not finished immediately on disconnect)
1323+ assert not produced_task .done ()
1324+
1325+ # Allow executor to finish, which should complete producer and then cleanup
1326+ execute_finish .set ()
1327+ await asyncio .wait_for (produced_task , timeout = 0.2 )
1328+ await asyncio .wait_for (cleanup_task , timeout = 0.2 )
1329+
1330+ # Queue close awaited by _run_event_stream
1331+ mock_queue .close .assert_awaited_once ()
1332+ # QueueManager close called by _cleanup_producer
1333+ mock_queue_manager .close .assert_awaited_once_with (task_id )
1334+ # Running agents is cleared
1335+ assert task_id not in request_handler ._running_agents
1336+
1337+
11321338@pytest .mark .asyncio
11331339async def test_on_message_send_stream_task_id_mismatch ():
11341340 """Test on_message_send_stream raises error if yielded task ID mismatches."""
0 commit comments