|
1 | | -import asyncio |
2 | | -from unittest.mock import AsyncMock, patch |
3 | | - |
4 | 1 | import pytest |
5 | 2 |
|
6 | | - |
7 | | -from agentic_security.mcp.client import run, ClientSession |
8 | | - |
9 | | - |
10 | | -# Fixtures |
11 | | -@pytest.fixture(scope="session") |
12 | | -def event_loop(): |
13 | | - """Create an instance of the default event loop for each test case.""" |
14 | | - loop = asyncio.get_event_loop_policy().new_event_loop() |
15 | | - yield loop |
16 | | - loop.close() |
17 | | - |
18 | | - |
19 | | -@pytest.fixture |
20 | | -async def mock_session(): |
21 | | - with patch("mcp.client.stdio.stdio_client") as mock_client: |
22 | | - mock_read = AsyncMock() |
23 | | - mock_write = AsyncMock() |
24 | | - |
25 | | - # Configures mock client such that mock responses are returned |
26 | | - mock_client.return_value.__aenter__.return_value = (mock_read, mock_write) |
27 | | - |
28 | | - # Creates a mock session |
29 | | - mock_session = AsyncMock(spec=ClientSession) |
30 | | - |
31 | | - # Expected responses |
32 | | - mock_session.initialize = AsyncMock() |
33 | | - mock_session.list_prompts = AsyncMock(return_value=["test_prompt"]) |
34 | | - mock_session.list_resources = AsyncMock(return_value=["test_resource"]) |
35 | | - mock_session.list_tools = AsyncMock(return_value=["echo_tool"]) |
36 | | - mock_session.call_tool = AsyncMock(return_value="Hello from client!") |
37 | | - |
38 | | - with patch("mcp.ClientSession", return_value=mock_session): |
39 | | - yield mock_session |
40 | | - |
41 | | - |
42 | | -# Tests |
43 | | - |
44 | | - |
45 | | -@pytest.mark.asyncio |
46 | | -async def test_mcp_initialization(mock_session): |
47 | | - """Test initialization success and failure cases""" |
48 | | - # Test initialization success case |
49 | | - await run() |
50 | | - mock_session.initialize.assert_called_once() |
51 | | - |
52 | | - # Resetting the mock to test for failure case |
53 | | - mock_session.initialize.reset_mock() |
54 | | - mock_session.initialize.side_effect = ConnectionError("Failed to connect") |
55 | | - |
56 | | - # Test connection error |
57 | | - with pytest.raises(ConnectionError): |
58 | | - await run() |
59 | | - |
60 | | - |
61 | | -@pytest.mark.asyncio |
62 | | -async def test_mcp_list_resources(mock_session): |
63 | | - """Test listing available resources""" |
64 | | - await run() |
65 | | - mock_session.list_resources.assert_called_once() |
66 | | - assert await mock_session.list_resources() == ["test_mcp_resource"] |
67 | | - |
68 | | - |
69 | | -@pytest.mark.asyncio |
70 | | -async def test_mcp_list_tools(mock_session): |
71 | | - """Test listing available tools""" |
72 | | - await run() |
73 | | - mock_session.list_tools.assert_called_once() |
74 | | - assert await mock_session.list_tools() == ["echo_tool"] |
| 3 | +from agentic_security.mcp.client import run |
75 | 4 |
|
76 | 5 |
|
77 | 6 | @pytest.mark.asyncio |
78 | | -async def test_mcp_echo_tool(mock_session): |
| 7 | +async def test_mcp_echo_tool(): |
79 | 8 | """Test the echo tool functionality""" |
80 | | - await run() |
81 | | - mock_session.call_tool.assert_called_once_with( |
82 | | - "echo_tool", arguments={"message": "Hello from client!"} |
83 | | - ) |
| 9 | + prompts, resources, tools = await run() |
| 10 | + assert prompts |
| 11 | + assert resources |
| 12 | + assert tools |
0 commit comments