Skip to content

Commit 881955c

Browse files
committed
adding 12 tests for InMemoryQueueManager
1 parent 53a5861 commit 881955c

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from contextlib import asynccontextmanager
2+
3+
import pytest
4+
import asyncio
5+
from unittest.mock import MagicMock, patch
6+
7+
from a2a.server.events import InMemoryQueueManager
8+
from a2a.server.events.event_queue import EventQueue
9+
from a2a.server.events.queue_manager import (
10+
NoTaskQueue,
11+
TaskQueueExists,
12+
)
13+
14+
15+
16+
class TestInMemoryQueueManager:
17+
@pytest.fixture
18+
def queue_manager(self):
19+
"""Fixture to create a fresh InMemoryQueueManager for each test."""
20+
manager = InMemoryQueueManager()
21+
return manager
22+
23+
@pytest.fixture
24+
def event_queue(self):
25+
"""Fixture to create a mock EventQueue."""
26+
queue = MagicMock(spec=EventQueue)
27+
# Mock the tap method to return itself
28+
queue.tap.return_value = queue
29+
return queue
30+
31+
@pytest.mark.asyncio
32+
async def test_init(self, queue_manager):
33+
"""Test that the InMemoryQueueManager initializes with empty task queue and a lock."""
34+
assert queue_manager._task_queue == {}
35+
assert isinstance(queue_manager._lock, asyncio.Lock)
36+
37+
@pytest.mark.asyncio
38+
async def test_add_new_queue(self, queue_manager, event_queue):
39+
"""Test adding a new queue to the manager."""
40+
task_id = "test_task_id"
41+
await queue_manager.add(task_id, event_queue)
42+
assert queue_manager._task_queue[task_id] == event_queue
43+
44+
@pytest.mark.asyncio
45+
async def test_add_existing_queue(self, queue_manager, event_queue):
46+
"""Test adding a queue with an existing task_id raises TaskQueueExists."""
47+
task_id = "test_task_id"
48+
await queue_manager.add(task_id, event_queue)
49+
50+
with pytest.raises(TaskQueueExists):
51+
await queue_manager.add(task_id, event_queue)
52+
53+
@pytest.mark.asyncio
54+
async def test_get_existing_queue(self, queue_manager, event_queue):
55+
"""Test getting an existing queue returns the queue."""
56+
task_id = "test_task_id"
57+
await queue_manager.add(task_id, event_queue)
58+
59+
result = await queue_manager.get(task_id)
60+
assert result == event_queue
61+
62+
@pytest.mark.asyncio
63+
async def test_get_nonexistent_queue(self, queue_manager):
64+
"""Test getting a non-existent queue returns None."""
65+
result = await queue_manager.get("nonexistent_task_id")
66+
assert result is None
67+
68+
@pytest.mark.asyncio
69+
async def test_tap_existing_queue(self, queue_manager, event_queue):
70+
"""Test tapping an existing queue returns the tapped queue."""
71+
task_id = "test_task_id"
72+
await queue_manager.add(task_id, event_queue)
73+
74+
result = await queue_manager.tap(task_id)
75+
assert result == event_queue
76+
event_queue.tap.assert_called_once()
77+
78+
@pytest.mark.asyncio
79+
async def test_tap_nonexistent_queue(self, queue_manager):
80+
"""Test tapping a non-existent queue returns None."""
81+
result = await queue_manager.tap("nonexistent_task_id")
82+
assert result is None
83+
84+
@pytest.mark.asyncio
85+
async def test_close_existing_queue(self, queue_manager, event_queue):
86+
"""Test closing an existing queue removes it from the manager."""
87+
task_id = "test_task_id"
88+
await queue_manager.add(task_id, event_queue)
89+
90+
await queue_manager.close(task_id)
91+
assert task_id not in queue_manager._task_queue
92+
93+
@pytest.mark.asyncio
94+
async def test_close_nonexistent_queue(self, queue_manager):
95+
"""Test closing a non-existent queue raises NoTaskQueue."""
96+
with pytest.raises(NoTaskQueue):
97+
await queue_manager.close("nonexistent_task_id")
98+
99+
@pytest.mark.asyncio
100+
async def test_create_or_tap_new_queue(self, queue_manager):
101+
"""Test create_or_tap with a new task_id creates and returns a new queue."""
102+
task_id = "test_task_id"
103+
104+
result = await queue_manager.create_or_tap(task_id)
105+
assert isinstance(result, EventQueue)
106+
assert queue_manager._task_queue[task_id] == result
107+
108+
@pytest.mark.asyncio
109+
async def test_create_or_tap_existing_queue(self, queue_manager, event_queue):
110+
"""Test create_or_tap with an existing task_id taps and returns the existing queue."""
111+
task_id = "test_task_id"
112+
await queue_manager.add(task_id, event_queue)
113+
114+
result = await queue_manager.create_or_tap(task_id)
115+
116+
assert result == event_queue
117+
event_queue.tap.assert_called_once()
118+
119+
@pytest.mark.asyncio
120+
async def test_concurrency(self, queue_manager):
121+
"""Test concurrent access to the queue manager."""
122+
123+
async def add_task(task_id):
124+
queue = EventQueue()
125+
await queue_manager.add(task_id, queue)
126+
return task_id
127+
128+
async def get_task(task_id):
129+
return await queue_manager.get(task_id)
130+
131+
# Create 10 different task IDs
132+
task_ids = [f"task_{i}" for i in range(10)]
133+
134+
# Add tasks concurrently
135+
add_tasks = [add_task(task_id) for task_id in task_ids]
136+
added_task_ids = await asyncio.gather(*add_tasks)
137+
138+
# Verify all tasks were added
139+
assert set(added_task_ids) == set(task_ids)
140+
141+
# Get tasks concurrently
142+
get_tasks = [get_task(task_id) for task_id in task_ids]
143+
queues = await asyncio.gather(*get_tasks)
144+
145+
# Verify all queues are not None
146+
assert all(queue is not None for queue in queues)
147+
148+
# Verify all tasks are in the manager
149+
for task_id in task_ids:
150+
assert task_id in queue_manager._task_queue

0 commit comments

Comments
 (0)