|
2 | 2 | Comprehensive async safety tests for mcpgateway.
|
3 | 3 | """
|
4 | 4 |
|
| 5 | +from typing import Any, List |
5 | 6 | import pytest
|
6 | 7 | import asyncio
|
7 |
| -import warnings |
8 | 8 | import time
|
9 |
| -from unittest.mock import AsyncMock, patch |
10 | 9 |
|
11 | 10 |
|
12 | 11 | class TestAsyncSafety:
|
13 | 12 | """Test async safety and proper coroutine handling."""
|
14 | 13 |
|
15 |
| - def test_no_unawaited_coroutines(self): |
16 |
| - """Test that no coroutines are left unawaited.""" |
17 |
| - |
18 |
| - # Capture async warnings |
19 |
| - with warnings.catch_warnings(record=True) as caught_warnings: |
20 |
| - warnings.simplefilter("always") |
21 |
| - |
22 |
| - # Run async code that might have unawaited coroutines |
23 |
| - asyncio.run(self._test_async_operations()) |
24 |
| - |
25 |
| - # Check for unawaited coroutine warnings |
26 |
| - unawaited_warnings = [w for w in caught_warnings if "coroutine" in str(w.message) and "never awaited" in str(w.message)] |
27 |
| - |
28 |
| - assert len(unawaited_warnings) == 0, f"Found {len(unawaited_warnings)} unawaited coroutines" |
29 |
| - |
30 |
| - async def _test_async_operations(self): |
31 |
| - """Test various async operations for safety.""" |
32 |
| - |
33 |
| - # Test WebSocket operations |
34 |
| - await self._test_websocket_safety() |
35 |
| - |
36 |
| - # Test database operations |
37 |
| - await self._test_database_safety() |
38 |
| - |
39 |
| - # Test MCP operations |
40 |
| - await self._test_mcp_safety() |
41 |
| - |
42 |
| - async def _test_websocket_safety(self): |
43 |
| - """Test WebSocket async safety.""" |
44 |
| - |
45 |
| - # Mock WebSocket operations |
46 |
| - with patch("websockets.connect") as mock_connect: |
47 |
| - mock_websocket = AsyncMock() |
48 |
| - mock_connect.return_value.__aenter__.return_value = mock_websocket |
49 |
| - |
50 |
| - # Test proper awaiting |
51 |
| - async with mock_connect("ws://test") as websocket: |
52 |
| - await websocket.send("test") |
53 |
| - await websocket.recv() |
54 |
| - |
55 |
| - async def _test_database_safety(self): |
56 |
| - """Test database async safety.""" |
57 |
| - |
58 |
| - # Mock database operations |
59 |
| - with patch("asyncpg.connect") as mock_connect: |
60 |
| - mock_connection = AsyncMock() |
61 |
| - mock_connect.return_value = mock_connection |
62 |
| - |
63 |
| - # Test proper connection handling |
64 |
| - connection = await mock_connect("postgresql://test") |
65 |
| - await connection.execute("SELECT 1") |
66 |
| - await connection.close() |
67 |
| - |
68 |
| - async def _test_mcp_safety(self): |
69 |
| - """Test MCP async safety.""" |
70 |
| - |
71 |
| - # Mock MCP operations |
72 |
| - with patch("aiohttp.ClientSession") as mock_session: |
73 |
| - mock_response = AsyncMock() |
74 |
| - mock_session.return_value.post.return_value.__aenter__.return_value = mock_response |
75 |
| - |
76 |
| - # Test proper session handling |
77 |
| - async with mock_session() as session: |
78 |
| - async with session.post("http://test") as response: |
79 |
| - await response.json() |
80 |
| - |
81 | 14 | @pytest.mark.asyncio
|
82 | 15 | async def test_concurrent_operations_performance(self):
|
83 | 16 | """Test performance of concurrent async operations."""
|
@@ -108,7 +41,7 @@ async def background_task():
|
108 | 41 | await asyncio.sleep(0.1)
|
109 | 42 |
|
110 | 43 | # Create and properly manage tasks
|
111 |
| - tasks = [] |
| 44 | + tasks: List[Any] = [] |
112 | 45 | for _ in range(10):
|
113 | 46 | task = asyncio.create_task(background_task())
|
114 | 47 | tasks.append(task)
|
|
0 commit comments