1+ from contextlib import asynccontextmanager
2+
3+ import pytest
4+ import asyncio
5+ from unittest .mock import MagicMock , patch
6+
7+ from a2a .server .events import InMemoryQueueManager
8+ from a2a .server .events .event_queue import EventQueue
9+ from a2a .server .events .queue_manager import (
10+ NoTaskQueue ,
11+ TaskQueueExists ,
12+ )
13+
14+
15+
16+ class TestInMemoryQueueManager :
17+ @pytest .fixture
18+ def queue_manager (self ):
19+ """Fixture to create a fresh InMemoryQueueManager for each test."""
20+ manager = InMemoryQueueManager ()
21+ return manager
22+
23+ @pytest .fixture
24+ def event_queue (self ):
25+ """Fixture to create a mock EventQueue."""
26+ queue = MagicMock (spec = EventQueue )
27+ # Mock the tap method to return itself
28+ queue .tap .return_value = queue
29+ return queue
30+
31+ @pytest .mark .asyncio
32+ async def test_init (self , queue_manager ):
33+ """Test that the InMemoryQueueManager initializes with empty task queue and a lock."""
34+ assert queue_manager ._task_queue == {}
35+ assert isinstance (queue_manager ._lock , asyncio .Lock )
36+
37+ @pytest .mark .asyncio
38+ async def test_add_new_queue (self , queue_manager , event_queue ):
39+ """Test adding a new queue to the manager."""
40+ task_id = "test_task_id"
41+ await queue_manager .add (task_id , event_queue )
42+ assert queue_manager ._task_queue [task_id ] == event_queue
43+
44+ @pytest .mark .asyncio
45+ async def test_add_existing_queue (self , queue_manager , event_queue ):
46+ """Test adding a queue with an existing task_id raises TaskQueueExists."""
47+ task_id = "test_task_id"
48+ await queue_manager .add (task_id , event_queue )
49+
50+ with pytest .raises (TaskQueueExists ):
51+ await queue_manager .add (task_id , event_queue )
52+
53+ @pytest .mark .asyncio
54+ async def test_get_existing_queue (self , queue_manager , event_queue ):
55+ """Test getting an existing queue returns the queue."""
56+ task_id = "test_task_id"
57+ await queue_manager .add (task_id , event_queue )
58+
59+ result = await queue_manager .get (task_id )
60+ assert result == event_queue
61+
62+ @pytest .mark .asyncio
63+ async def test_get_nonexistent_queue (self , queue_manager ):
64+ """Test getting a non-existent queue returns None."""
65+ result = await queue_manager .get ("nonexistent_task_id" )
66+ assert result is None
67+
68+ @pytest .mark .asyncio
69+ async def test_tap_existing_queue (self , queue_manager , event_queue ):
70+ """Test tapping an existing queue returns the tapped queue."""
71+ task_id = "test_task_id"
72+ await queue_manager .add (task_id , event_queue )
73+
74+ result = await queue_manager .tap (task_id )
75+ assert result == event_queue
76+ event_queue .tap .assert_called_once ()
77+
78+ @pytest .mark .asyncio
79+ async def test_tap_nonexistent_queue (self , queue_manager ):
80+ """Test tapping a non-existent queue returns None."""
81+ result = await queue_manager .tap ("nonexistent_task_id" )
82+ assert result is None
83+
84+ @pytest .mark .asyncio
85+ async def test_close_existing_queue (self , queue_manager , event_queue ):
86+ """Test closing an existing queue removes it from the manager."""
87+ task_id = "test_task_id"
88+ await queue_manager .add (task_id , event_queue )
89+
90+ await queue_manager .close (task_id )
91+ assert task_id not in queue_manager ._task_queue
92+
93+ @pytest .mark .asyncio
94+ async def test_close_nonexistent_queue (self , queue_manager ):
95+ """Test closing a non-existent queue raises NoTaskQueue."""
96+ with pytest .raises (NoTaskQueue ):
97+ await queue_manager .close ("nonexistent_task_id" )
98+
99+ @pytest .mark .asyncio
100+ async def test_create_or_tap_new_queue (self , queue_manager ):
101+ """Test create_or_tap with a new task_id creates and returns a new queue."""
102+ task_id = "test_task_id"
103+
104+ result = await queue_manager .create_or_tap (task_id )
105+ assert isinstance (result , EventQueue )
106+ assert queue_manager ._task_queue [task_id ] == result
107+
108+ @pytest .mark .asyncio
109+ async def test_create_or_tap_existing_queue (self , queue_manager , event_queue ):
110+ """Test create_or_tap with an existing task_id taps and returns the existing queue."""
111+ task_id = "test_task_id"
112+ await queue_manager .add (task_id , event_queue )
113+
114+ result = await queue_manager .create_or_tap (task_id )
115+
116+ assert result == event_queue
117+ event_queue .tap .assert_called_once ()
118+
119+ @pytest .mark .asyncio
120+ async def test_concurrency (self , queue_manager ):
121+ """Test concurrent access to the queue manager."""
122+
123+ async def add_task (task_id ):
124+ queue = EventQueue ()
125+ await queue_manager .add (task_id , queue )
126+ return task_id
127+
128+ async def get_task (task_id ):
129+ return await queue_manager .get (task_id )
130+
131+ # Create 10 different task IDs
132+ task_ids = [f"task_{ i } " for i in range (10 )]
133+
134+ # Add tasks concurrently
135+ add_tasks = [add_task (task_id ) for task_id in task_ids ]
136+ added_task_ids = await asyncio .gather (* add_tasks )
137+
138+ # Verify all tasks were added
139+ assert set (added_task_ids ) == set (task_ids )
140+
141+ # Get tasks concurrently
142+ get_tasks = [get_task (task_id ) for task_id in task_ids ]
143+ queues = await asyncio .gather (* get_tasks )
144+
145+ # Verify all queues are not None
146+ assert all (queue is not None for queue in queues )
147+
148+ # Verify all tasks are in the manager
149+ for task_id in task_ids :
150+ assert task_id in queue_manager ._task_queue
0 commit comments