Skip to content

Commit 47ca656

Browse files
authored
Merge pull request #213 from sjay8/main
Fixed issues 191 195
2 parents 77557ad + 4fa1662 commit 47ca656

File tree

3 files changed

+118
-9
lines changed

3 files changed

+118
-9
lines changed

agentic_security/mcp/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
)
1212

1313

14-
async def run():
14+
async def run()-> None:
1515
async with stdio_client(server_params) as (read, write):
1616
async with ClientSession(read, write) as session:
17-
# Initialize the connection
17+
# Initialize the connection --> connection doesnt work
1818
await session.initialize()
1919

20-
# List available prompts, resources, and tools
20+
# List available prompts, resources, and tools --> no avalialbe tools
2121
prompts = await session.list_prompts()
2222
print(f"Available prompts: {prompts}")
2323

@@ -27,7 +27,7 @@ async def run():
2727
tools = await session.list_tools()
2828
print(f"Available tools: {tools}")
2929

30-
# Call the echo tool
30+
# Call the echo tool --> echo tool iisue
3131
echo_result = await session.call_tool(
3232
"echo_tool", arguments={"message": "Hello from client!"}
3333
)

agentic_security/mcp/main.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414

1515
@mcp.tool()
1616
async def verify_llm(spec: str) -> dict:
17-
"""Verify an LLM model specification using the FastAPI server."""
17+
"""
18+
Verify an LLM model specification using the FastAPI server
19+
20+
Returns:
21+
dict: containing the verification result form the FastAPI server
22+
23+
Args: spect(str): The specification of the LLM model to verify.
24+
25+
"""
1826
url = f"{AGENTIC_SECURITY}/verify"
1927
async with httpx.AsyncClient() as client:
2028
response = await client.post(url, json={"spec": spec})
@@ -28,7 +36,18 @@ async def start_scan(
2836
optimize: bool = False,
2937
enableMultiStepAttack: bool = False,
3038
) -> dict:
31-
"""Start an LLM security scan via the FastAPI server."""
39+
"""
40+
Start an LLM security scan via the FastAPI server.
41+
Returns:
42+
dict: The scan initiation result from the FastAPI server.
43+
44+
Args:
45+
llmSpec (str): The specification of the LLM model.
46+
maxBudget (int): The maximum budget for the scan.
47+
optimize (bool, optional): Whether to enable optimization during scanning. Defaults to False.
48+
enableMultiStepAttack (bool, optional): Whether to enable multi-step attack
49+
50+
"""
3251
url = f"{AGENTIC_SECURITY}/scan"
3352
payload = {
3453
"llmSpec": llmSpec,
@@ -46,7 +65,11 @@ async def start_scan(
4665

4766
@mcp.tool()
4867
async def stop_scan() -> dict:
49-
"""Stop an ongoing scan via the FastAPI server."""
68+
"""Stop an ongoing scan via the FastAPI server.
69+
70+
Returns:
71+
dict: The confirmation from the FastAPI server that the scan has been stopped.
72+
"""
5073
url = f"{AGENTIC_SECURITY}/stop"
5174
async with httpx.AsyncClient() as client:
5275
response = await client.post(url)
@@ -55,7 +78,12 @@ async def stop_scan() -> dict:
5578

5679
@mcp.tool()
5780
async def get_data_config() -> list:
58-
"""Retrieve data configuration from the FastAPI server."""
81+
"""
82+
Retrieve data configuration from the FastAPI server.
83+
84+
Returns:
85+
list: The response from the FastAPI server, confirming the scan has been stopped.
86+
"""
5987
url = f"{AGENTIC_SECURITY}/v1/data-config"
6088
async with httpx.AsyncClient() as client:
6189
response = await client.get(url)
@@ -64,7 +92,12 @@ async def get_data_config() -> list:
6492

6593
@mcp.tool()
6694
async def get_spec_templates() -> list:
67-
"""Retrieve data configuration from the FastAPI server."""
95+
"""
96+
Retrieve data configuration from the FastAPI server.
97+
98+
Returns:
99+
list: The LLM specification templates from the FastAPI server.
100+
"""
68101
url = f"{AGENTIC_SECURITY}/v1/llm-specs"
69102
async with httpx.AsyncClient() as client:
70103
response = await client.get(url)

tests/test_client.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import pytest
2+
import asyncio
3+
from unittest.mock import AsyncMock, patch
4+
from agentic_security.mcp import ClientSession
5+
from ..agentic_security.mcp.client import run
6+
7+
# Fixtures
8+
@pytest.fixture(scope="session")
9+
def event_loop():
10+
"""Create an instance of the default event loop for each test case."""
11+
loop = asyncio.get_event_loop_policy().new_event_loop()
12+
yield loop
13+
loop.close()
14+
15+
@pytest.fixture
16+
async def mock_session():
17+
with patch('mcp.client.stdio.stdio_client') as mock_client:
18+
19+
mock_read = AsyncMock()
20+
mock_write = AsyncMock()
21+
22+
# Configures mock client such that mock responses are returned
23+
mock_client.return_value.__aenter__.return_value = (mock_read, mock_write)
24+
25+
# Creates a mock session
26+
mock_session = AsyncMock(spec=ClientSession)
27+
28+
# Expected responses
29+
mock_session.initialize = AsyncMock()
30+
mock_session.list_prompts = AsyncMock(return_value=["test_prompt"])
31+
mock_session.list_resources = AsyncMock(return_value=["test_resource"])
32+
mock_session.list_tools = AsyncMock(return_value=["echo_tool"])
33+
mock_session.call_tool = AsyncMock(return_value="Hello from client!")
34+
35+
with patch('mcp.ClientSession', return_value=mock_session):
36+
yield mock_session
37+
38+
# Tests
39+
40+
@pytest.mark.asyncio
41+
async def test_initialization(mock_session):
42+
"""Test initialization success and failure cases"""
43+
# Test initialization success case
44+
await run()
45+
mock_session.initialize.assert_called_once()
46+
47+
# Resetting the mock to test for failure case
48+
mock_session.initialize.reset_mock()
49+
mock_session.initialize.side_effect = ConnectionError("Failed to connect")
50+
51+
# Test connection error
52+
with pytest.raises(ConnectionError):
53+
await run()
54+
55+
@pytest.mark.asyncio
56+
async def test_list_resources(mock_session):
57+
"""Test listing available resources"""
58+
await run()
59+
mock_session.list_resources.assert_called_once()
60+
assert await mock_session.list_resources() == ["test_resource"]
61+
62+
@pytest.mark.asyncio
63+
async def test_list_tools(mock_session):
64+
"""Test listing available tools"""
65+
await run()
66+
mock_session.list_tools.assert_called_once()
67+
assert await mock_session.list_tools() == ["echo_tool"]
68+
69+
@pytest.mark.asyncio
70+
async def test_echo_tool(mock_session):
71+
"""Test the echo tool functionality"""
72+
await run()
73+
mock_session.call_tool.assert_called_once_with(
74+
"echo_tool",
75+
arguments={"message": "Hello from client!"}
76+
)

0 commit comments

Comments
 (0)