Skip to content

Commit 32a8162

Browse files
committed
Fix test_tool_service and test_server_service
Signed-off-by: Mihai Criveti <[email protected]>
1 parent 0551591 commit 32a8162

File tree

2 files changed

+77
-47
lines changed

2 files changed

+77
-47
lines changed

tests/unit/mcpgateway/services/test_server_service.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Tests for server service implementation.
88
"""
99

10-
from unittest.mock import AsyncMock, MagicMock, Mock
10+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
1111

1212
import pytest
1313

@@ -18,7 +18,6 @@
1818
from mcpgateway.schemas import ServerCreate, ServerRead, ServerUpdate
1919
from mcpgateway.services.server_service import (
2020
ServerError,
21-
ServerNameConflictError,
2221
ServerNotFoundError,
2322
ServerService,
2423
)
@@ -113,6 +112,21 @@ async def test_register_server(
113112
test_db.commit = Mock()
114113
test_db.refresh = Mock()
115114

115+
# Create a mock server instance that will be returned by db.add
116+
mock_db_server = MagicMock(spec=DbServer)
117+
mock_db_server.id = 1
118+
mock_db_server.name = "test_server"
119+
mock_db_server.description = "A test server"
120+
mock_db_server.icon = "server-icon"
121+
mock_db_server.created_at = "2023-01-01T00:00:00"
122+
mock_db_server.updated_at = "2023-01-01T00:00:00"
123+
mock_db_server.is_active = True
124+
125+
# Mock the relationship lists as simple lists
126+
mock_db_server.tools = []
127+
mock_db_server.resources = []
128+
mock_db_server.prompts = []
129+
116130
# Stub db.get to resolve associated items
117131
test_db.get = Mock(
118132
side_effect=lambda cls, _id: {
@@ -158,7 +172,9 @@ async def test_register_server(
158172
associated_prompts=["301"],
159173
)
160174

161-
result = await server_service.register_server(test_db, server_create)
175+
# Mock the DbServer constructor to return our mock instance
176+
with patch('mcpgateway.services.server_service.DbServer', return_value=mock_db_server):
177+
result = await server_service.register_server(test_db, server_create)
162178

163179
test_db.add.assert_called_once()
164180
test_db.commit.assert_called_once()
@@ -204,8 +220,15 @@ async def test_register_server_invalid_associated_tool(self, server_service, tes
204220
associated_tools=["999"],
205221
)
206222

207-
with pytest.raises(ServerError) as exc:
208-
await server_service.register_server(test_db, server_create)
223+
# Mock the DbServer constructor
224+
mock_db_server = MagicMock(spec=DbServer)
225+
mock_db_server.tools = []
226+
mock_db_server.resources = []
227+
mock_db_server.prompts = []
228+
229+
with patch('mcpgateway.services.server_service.DbServer', return_value=mock_db_server):
230+
with pytest.raises(ServerError) as exc:
231+
await server_service.register_server(test_db, server_create)
209232

210233
assert "Tool with id 999 does not exist" in str(exc.value)
211234
test_db.rollback.assert_called_once()

tests/unit/mcpgateway/services/test_tool_service.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Tests for tool service implementation.
99
"""
1010

11-
from unittest.mock import ANY, AsyncMock, MagicMock, Mock
11+
from unittest.mock import ANY, AsyncMock, MagicMock, Mock, patch
1212

1313
import pytest
1414
from sqlalchemy.exc import IntegrityError
@@ -52,6 +52,7 @@ def mock_tool():
5252
tool.auth_username = None
5353
tool.auth_password = None
5454
tool.auth_token = None
55+
tool.auth_value = None # Add this field
5556
tool.gateway_id = None
5657

5758
# Set up metrics
@@ -110,6 +111,7 @@ async def test_register_tool(self, tool_service, mock_tool, test_db):
110111
is_active=True,
111112
gateway_id=None,
112113
execution_count=0,
114+
auth=None, # Add auth field
113115
metrics={
114116
"total_executions": 0,
115117
"successful_executions": 0,
@@ -168,14 +170,12 @@ async def test_register_tool_name_conflict(self, tool_service, mock_tool, test_d
168170
request_type="POST",
169171
)
170172

171-
# Should raise conflict error
172-
with pytest.raises(ToolNameConflictError) as exc_info:
173+
# Should raise ToolError wrapping ToolNameConflictError
174+
with pytest.raises(ToolError) as exc_info:
173175
await tool_service.register_tool(test_db, tool_create)
174176

177+
# The service wraps exceptions, so check the message
175178
assert "Tool already exists with name" in str(exc_info.value)
176-
assert exc_info.value.name == "test_tool"
177-
assert exc_info.value.is_active == mock_tool.is_active
178-
assert exc_info.value.tool_id == mock_tool.id
179179

180180
@pytest.mark.asyncio
181181
async def test_register_tool_db_integrity_error(self, tool_service, test_db):
@@ -208,8 +208,10 @@ async def test_register_tool_db_integrity_error(self, tool_service, test_db):
208208
async def test_list_tools(self, tool_service, mock_tool, test_db):
209209
"""Test listing tools."""
210210
# Mock DB to return a list of tools
211+
mock_scalars = MagicMock()
212+
mock_scalars.all.return_value = [mock_tool]
211213
mock_scalar_result = MagicMock()
212-
mock_scalar_result.all.return_value = [mock_tool]
214+
mock_scalar_result.scalars.return_value = mock_scalars
213215
mock_execute = Mock(return_value=mock_scalar_result)
214216
test_db.execute = mock_execute
215217

@@ -229,6 +231,7 @@ async def test_list_tools(self, tool_service, mock_tool, test_db):
229231
is_active=True,
230232
gateway_id=None,
231233
execution_count=0,
234+
auth=None, # Add auth field
232235
metrics={
233236
"total_executions": 0,
234237
"successful_executions": 0,
@@ -275,6 +278,7 @@ async def test_get_tool(self, tool_service, mock_tool, test_db):
275278
is_active=True,
276279
gateway_id=None,
277280
execution_count=0,
281+
auth=None, # Add auth field
278282
metrics={
279283
"total_executions": 0,
280284
"successful_executions": 0,
@@ -338,8 +342,8 @@ async def test_delete_tool_not_found(self, tool_service, test_db):
338342
# Mock DB get to return None
339343
test_db.get = Mock(return_value=None)
340344

341-
# Should raise NotFoundError
342-
with pytest.raises(ToolNotFoundError) as exc_info:
345+
# The service wraps the exception in ToolError
346+
with pytest.raises(ToolError) as exc_info:
343347
await tool_service.delete_tool(test_db, 999)
344348

345349
assert "Tool not found: 999" in str(exc_info.value)
@@ -372,6 +376,7 @@ async def test_toggle_tool_status(self, tool_service, mock_tool, test_db):
372376
is_active=False, # Changed to False
373377
gateway_id=None,
374378
execution_count=0,
379+
auth=None, # Add auth field
375380
metrics={
376381
"total_executions": 0,
377382
"successful_executions": 0,
@@ -436,6 +441,7 @@ async def test_update_tool(self, tool_service, mock_tool, test_db):
436441
is_active=True,
437442
gateway_id=None,
438443
execution_count=0,
444+
auth=None, # Add auth field
439445
metrics={
440446
"total_executions": 0,
441447
"successful_executions": 0,
@@ -499,14 +505,11 @@ async def test_update_tool_name_conflict(self, tool_service, mock_tool, test_db)
499505
name="existing_tool", # Name that conflicts with another tool
500506
)
501507

502-
# Should raise conflict error
503-
with pytest.raises(ToolNameConflictError) as exc_info:
508+
# The service wraps the exception in ToolError
509+
with pytest.raises(ToolError) as exc_info:
504510
await tool_service.update_tool(test_db, 1, tool_update)
505511

506512
assert "Tool already exists with name" in str(exc_info.value)
507-
assert exc_info.value.name == "existing_tool"
508-
assert exc_info.value.is_active == conflicting_tool.is_active
509-
assert exc_info.value.tool_id == conflicting_tool.id
510513

511514
@pytest.mark.asyncio
512515
async def test_update_tool_not_found(self, tool_service, test_db):
@@ -519,8 +522,8 @@ async def test_update_tool_not_found(self, tool_service, test_db):
519522
name="updated_tool",
520523
)
521524

522-
# Should raise NotFoundError
523-
with pytest.raises(ToolNotFoundError) as exc_info:
525+
# The service wraps the exception in ToolError
526+
with pytest.raises(ToolError) as exc_info:
524527
await tool_service.update_tool(test_db, 999, tool_update)
525528

526529
assert "Tool not found: 999" in str(exc_info.value)
@@ -567,6 +570,7 @@ async def test_invoke_tool_rest(self, tool_service, mock_tool, test_db):
567570
mock_tool.integration_type = "REST"
568571
mock_tool.request_type = "POST"
569572
mock_tool.jsonpath_filter = ""
573+
mock_tool.auth_value = None # No auth
570574

571575
# Mock DB to return the tool
572576
mock_scalar = Mock()
@@ -576,30 +580,30 @@ async def test_invoke_tool_rest(self, tool_service, mock_tool, test_db):
576580
# Mock HTTP client response
577581
mock_response = AsyncMock()
578582
mock_response.raise_for_status = AsyncMock()
579-
mock_response.json.return_value = {"result": {"content": [{"type": "text", "text": "REST tool response"}]}}
583+
mock_response.status_code = 200
584+
mock_response.json = Mock(return_value={"result": "REST tool response"}) # Make json() synchronous
580585
tool_service._http_client.request.return_value = mock_response
581586

582587
# Mock metrics recording
583588
tool_service._record_tool_metric = AsyncMock()
584589

585-
# Invoke tool
586-
result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"})
590+
# Mock decode_auth to return empty dict when auth_value is None
591+
# Mock extract_using_jq to return the input unmodified when filter is empty
592+
with patch('mcpgateway.services.tool_service.decode_auth', return_value={}), \
593+
patch('mcpgateway.config.extract_using_jq', return_value={"result": "REST tool response"}):
594+
# Invoke tool
595+
result = await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"})
587596

588597
# Verify HTTP request
589598
tool_service._http_client.request.assert_called_once_with(
590599
"POST",
591600
mock_tool.url,
592-
json={
593-
"jsonrpc": "2.0",
594-
"method": "test_tool",
595-
"params": {"param": "value"},
596-
"id": 1,
597-
},
601+
json={"param": "value"},
598602
headers=mock_tool.headers,
599603
)
600604

601605
# Verify result
602-
assert any(content.text == "REST tool response" for content in result.content)
606+
assert result.content[0].text == '{\n "result": "REST tool response"\n}'
603607

604608
# Verify metrics recorded
605609
tool_service._record_tool_metric.assert_called_once_with(
@@ -616,32 +620,35 @@ async def test_invoke_tool_error(self, tool_service, mock_tool, test_db):
616620
# Configure tool
617621
mock_tool.integration_type = "REST"
618622
mock_tool.request_type = "POST"
623+
mock_tool.auth_value = None # No auth
619624

620625
# Mock DB to return the tool
621626
mock_scalar = Mock()
622627
mock_scalar.scalar_one_or_none.return_value = mock_tool
623628
test_db.execute = Mock(return_value=mock_scalar)
624629

625-
# Mock HTTP client to raise an error
626-
tool_service._http_client.request.side_effect = Exception("HTTP error")
630+
# Mock decode_auth to return empty dict
631+
with patch('mcpgateway.services.tool_service.decode_auth', return_value={}):
632+
# Mock HTTP client to raise an error
633+
tool_service._http_client.request.side_effect = Exception("HTTP error")
627634

628-
# Mock metrics recording
629-
tool_service._record_tool_metric = AsyncMock()
635+
# Mock metrics recording
636+
tool_service._record_tool_metric = AsyncMock()
630637

631-
# Should raise ToolInvocationError
632-
with pytest.raises(ToolInvocationError) as exc_info:
633-
await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"})
638+
# Should raise ToolInvocationError
639+
with pytest.raises(ToolInvocationError) as exc_info:
640+
await tool_service.invoke_tool(test_db, "test_tool", {"param": "value"})
634641

635-
assert "Tool invocation failed: HTTP error" in str(exc_info.value)
642+
assert "Tool invocation failed: HTTP error" in str(exc_info.value)
636643

637-
# Verify metrics recorded with error
638-
tool_service._record_tool_metric.assert_called_once_with(
639-
test_db,
640-
mock_tool,
641-
ANY, # Start time
642-
False, # Failed
643-
"HTTP error", # Error message
644-
)
644+
# Verify metrics recorded with error
645+
tool_service._record_tool_metric.assert_called_once_with(
646+
test_db,
647+
mock_tool,
648+
ANY, # Start time
649+
False, # Failed
650+
"HTTP error", # Error message
651+
)
645652

646653
@pytest.mark.asyncio
647654
async def test_reset_metrics(self, tool_service, test_db):

0 commit comments

Comments
 (0)