Skip to content

Commit 412622e

Browse files
authored
unit test (#428)
Signed-off-by: Rakhi Dutta <[email protected]>
1 parent fdc80f2 commit 412622e

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed

tests/unit/mcpgateway/test_db.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
4+
Copyright 2025
5+
SPDX-License-Identifier: Apache-2.0
6+
Authors: Mihai Criveti
7+
8+
"""
9+
10+
# Standard
11+
from datetime import datetime, timezone, timedelta
12+
from unittest.mock import MagicMock
13+
14+
# Third-Party
15+
import pytest
16+
from sqlalchemy.exc import SQLAlchemyError
17+
18+
# First-Party
19+
import mcpgateway.db as db
20+
21+
# --- utc_now ---
22+
def test_utc_now_returns_utc_datetime():
23+
now = db.utc_now()
24+
assert isinstance(now, datetime)
25+
assert now.tzinfo == timezone.utc
26+
27+
# --- Tool metrics properties ---
28+
def make_tool_with_metrics(metrics):
29+
tool = db.Tool()
30+
tool.metrics = metrics
31+
return tool
32+
33+
def test_tool_metrics_properties():
34+
now = datetime.now(timezone.utc)
35+
metrics = [
36+
db.ToolMetric(response_time=1.0, is_success=True, timestamp=now),
37+
db.ToolMetric(response_time=2.0, is_success=False, timestamp=now + timedelta(seconds=1)),
38+
]
39+
tool = make_tool_with_metrics(metrics)
40+
assert tool.execution_count == 2
41+
assert tool.successful_executions == 1
42+
assert tool.failed_executions == 1
43+
assert tool.failure_rate == 0.5
44+
assert tool.min_response_time == 1.0
45+
assert tool.max_response_time == 2.0
46+
assert tool.avg_response_time == 1.5
47+
assert tool.last_execution_time == now + timedelta(seconds=1)
48+
summary = tool.metrics_summary
49+
assert summary["total_executions"] == 2
50+
assert summary["failure_rate"] == 0.5
51+
52+
def test_tool_metrics_properties_empty():
53+
tool = db.Tool()
54+
tool.metrics = []
55+
assert tool.execution_count == 0
56+
assert tool.successful_executions == 0
57+
assert tool.failed_executions == 0
58+
assert tool.failure_rate == 0.0
59+
assert tool.min_response_time is None
60+
assert tool.max_response_time is None
61+
assert tool.avg_response_time is None
62+
assert tool.last_execution_time is None
63+
64+
# --- Resource metrics properties ---
65+
def make_resource_with_metrics(metrics):
66+
resource = db.Resource()
67+
resource.metrics = metrics
68+
return resource
69+
70+
def test_resource_metrics_properties():
71+
now = datetime.now(timezone.utc)
72+
metrics = [
73+
db.ResourceMetric(response_time=1.0, is_success=True, timestamp=now),
74+
db.ResourceMetric(response_time=2.0, is_success=False, timestamp=now + timedelta(seconds=1)),
75+
]
76+
resource = make_resource_with_metrics(metrics)
77+
assert resource.execution_count == 2
78+
assert resource.successful_executions == 1
79+
assert resource.failed_executions == 1
80+
assert resource.failure_rate == 0.5
81+
assert resource.min_response_time == 1.0
82+
assert resource.max_response_time == 2.0
83+
assert resource.avg_response_time == 1.5
84+
assert resource.last_execution_time == now + timedelta(seconds=1)
85+
86+
87+
def test_resource_metrics_properties_empty():
88+
resource = db.Resource()
89+
resource.metrics = []
90+
assert resource.execution_count == 0
91+
assert resource.successful_executions == 0
92+
assert resource.failed_executions == 0
93+
assert resource.failure_rate == 0.0
94+
assert resource.min_response_time is None
95+
assert resource.max_response_time is None
96+
assert resource.avg_response_time is None
97+
assert resource.last_execution_time is None
98+
99+
# --- Prompt metrics properties ---
100+
def make_prompt_with_metrics(metrics):
101+
prompt = db.Prompt()
102+
prompt.metrics = metrics
103+
return prompt
104+
105+
def test_prompt_metrics_properties():
106+
now = datetime.now(timezone.utc)
107+
metrics = [
108+
db.PromptMetric(response_time=1.0, is_success=True, timestamp=now),
109+
db.PromptMetric(response_time=2.0, is_success=False, timestamp=now + timedelta(seconds=1)),
110+
]
111+
prompt = make_prompt_with_metrics(metrics)
112+
assert prompt.execution_count == 2
113+
assert prompt.successful_executions == 1
114+
assert prompt.failed_executions == 1
115+
assert prompt.failure_rate == 0.5
116+
assert prompt.min_response_time == 1.0
117+
assert prompt.max_response_time == 2.0
118+
assert prompt.avg_response_time == 1.5
119+
assert prompt.last_execution_time == now + timedelta(seconds=1)
120+
121+
def test_prompt_metrics_properties_empty():
122+
prompt = db.Prompt()
123+
prompt.metrics = []
124+
assert prompt.execution_count == 0
125+
assert prompt.successful_executions == 0
126+
assert prompt.failed_executions == 0
127+
assert prompt.failure_rate == 0.0
128+
assert prompt.min_response_time is None
129+
assert prompt.max_response_time is None
130+
assert prompt.avg_response_time is None
131+
assert prompt.last_execution_time is None
132+
133+
# --- Server metrics properties ---
134+
def make_server_with_metrics(metrics):
135+
server = db.Server()
136+
server.metrics = metrics
137+
return server
138+
139+
def test_server_metrics_properties():
140+
now = datetime.now(timezone.utc)
141+
metrics = [
142+
db.ServerMetric(response_time=1.0, is_success=True, timestamp=now),
143+
db.ServerMetric(response_time=2.0, is_success=False, timestamp=now + timedelta(seconds=1)),
144+
]
145+
server = make_server_with_metrics(metrics)
146+
assert server.execution_count == 2
147+
assert server.successful_executions == 1
148+
assert server.failed_executions == 1
149+
assert server.failure_rate == 0.5
150+
assert server.min_response_time == 1.0
151+
assert server.max_response_time == 2.0
152+
assert server.avg_response_time == 1.5
153+
assert server.last_execution_time == now + timedelta(seconds=1)
154+
155+
def test_server_metrics_properties_empty():
156+
server = db.Server()
157+
server.metrics = []
158+
assert server.execution_count == 0
159+
assert server.successful_executions == 0
160+
assert server.failed_executions == 0
161+
assert server.failure_rate == 0.0
162+
assert server.min_response_time is None
163+
assert server.max_response_time is None
164+
assert server.avg_response_time is None
165+
assert server.last_execution_time is None
166+
167+
# --- Resource content property ---
168+
def test_resource_content_text():
169+
resource = db.Resource()
170+
resource.text_content = "hello"
171+
resource.binary_content = None
172+
resource.uri = "uri"
173+
resource.mime_type = "text/plain"
174+
content = resource.content
175+
assert content.text == "hello"
176+
assert content.type == "resource"
177+
assert content.uri == "uri"
178+
assert content.mime_type == "text/plain"
179+
180+
def test_resource_content_binary():
181+
resource = db.Resource()
182+
resource.text_content = None
183+
resource.binary_content = b"data"
184+
resource.uri = "uri"
185+
resource.mime_type = None
186+
content = resource.content
187+
assert content.blob == b"data"
188+
assert content.mime_type == "application/octet-stream"
189+
190+
def test_resource_content_none():
191+
resource = db.Resource()
192+
resource.text_content = None
193+
resource.binary_content = None
194+
with pytest.raises(ValueError):
195+
_ = resource.content
196+
197+
def test_resource_content_text_and_binary():
198+
resource = db.Resource()
199+
resource.text_content = "text"
200+
resource.binary_content = b"binary"
201+
resource.uri = "uri"
202+
resource.mime_type = "text/plain"
203+
content = resource.content
204+
assert content.text == "text"
205+
assert not hasattr(content, "blob") or content.blob is None
206+
207+
# --- Prompt argument validation ---
208+
def test_prompt_validate_arguments_valid():
209+
prompt = db.Prompt()
210+
prompt.argument_schema = {"type": "object", "properties": {"a": {"type": "string"}}, "required": ["a"]}
211+
prompt.validate_arguments({"a": "x"})
212+
213+
def test_prompt_validate_arguments_invalid():
214+
prompt = db.Prompt()
215+
prompt.argument_schema = {"type": "object", "properties": {"a": {"type": "string"}}, "required": ["a"]}
216+
with pytest.raises(ValueError):
217+
prompt.validate_arguments({})
218+
219+
def test_prompt_validate_arguments_missing_schema():
220+
prompt = db.Prompt()
221+
prompt.argument_schema = None
222+
with pytest.raises(Exception):
223+
prompt.validate_arguments({"a": "x"})
224+
225+
# --- Validation listeners ---
226+
def test_validate_tool_schema_valid():
227+
class Target:
228+
input_schema = {"type": "object"}
229+
db.validate_tool_schema(None, None, Target())
230+
231+
def test_validate_tool_schema_invalid():
232+
class Target:
233+
input_schema = {"type": "invalid"}
234+
with pytest.raises(ValueError):
235+
db.validate_tool_schema(None, None, Target())
236+
237+
def test_validate_tool_name_valid():
238+
class Target:
239+
name = "valid_name-123"
240+
db.validate_tool_name(None, None, Target())
241+
242+
def test_validate_tool_name_invalid():
243+
class Target:
244+
name = "invalid name!"
245+
with pytest.raises(ValueError):
246+
db.validate_tool_name(None, None, Target())
247+
248+
def test_validate_prompt_schema_valid():
249+
class Target:
250+
argument_schema = {"type": "object"}
251+
db.validate_prompt_schema(None, None, Target())
252+
253+
def test_validate_prompt_schema_invalid():
254+
class Target:
255+
argument_schema = {"type": "invalid"}
256+
with pytest.raises(ValueError):
257+
db.validate_prompt_schema(None, None, Target())
258+
259+
def test_validate_tool_schema_missing():
260+
class Target:
261+
pass
262+
db.validate_tool_schema(None, None, Target()) # Should not raise
263+
264+
def test_validate_tool_name_missing():
265+
class Target:
266+
pass
267+
db.validate_tool_name(None, None, Target()) # Should not raise
268+
269+
def test_validate_prompt_schema_missing():
270+
class Target:
271+
pass
272+
db.validate_prompt_schema(None, None, Target()) # Should not raise
273+
274+
# --- get_db generator ---
275+
def test_get_db_yields_and_closes(monkeypatch):
276+
class DummySession:
277+
def close(self):
278+
self.closed = True
279+
dummy = DummySession()
280+
monkeypatch.setattr(db, "SessionLocal", lambda: dummy)
281+
gen = db.get_db()
282+
session = next(gen)
283+
assert session is dummy
284+
try:
285+
next(gen)
286+
except StopIteration:
287+
pass
288+
assert hasattr(dummy, "closed")
289+
290+
def test_get_db_closes_on_exception(monkeypatch):
291+
class DummySession:
292+
def close(self):
293+
self.closed = True
294+
295+
dummy = DummySession()
296+
monkeypatch.setattr(db, "SessionLocal", lambda: dummy)
297+
298+
gen = db.get_db()
299+
session = next(gen)
300+
assert session is dummy
301+
302+
try:
303+
gen.throw(Exception("fail"))
304+
except Exception:
305+
pass
306+
307+
assert hasattr(dummy, "closed")
308+
# --- init_db ---
309+
def test_init_db_success(monkeypatch):
310+
monkeypatch.setattr(db.Base.metadata, "create_all", lambda bind: True)
311+
db.init_db()
312+
313+
def test_init_db_failure(monkeypatch):
314+
def fail(*a, **k):
315+
raise SQLAlchemyError("fail")
316+
monkeypatch.setattr(db.Base.metadata, "create_all", fail)
317+
with pytest.raises(Exception):
318+
db.init_db()
319+
320+
# --- Gateway event listener ---
321+
def test_update_tool_names_on_gateway_update(monkeypatch):
322+
class DummyGateway:
323+
id = "gwid"
324+
name = "GatewayName"
325+
class DummyConnection:
326+
def execute(self, stmt):
327+
self.executed = True
328+
class DummyMapper:
329+
pass
330+
monkeypatch.setattr(db.Tool, "__table__", MagicMock())
331+
monkeypatch.setattr(db, "slugify", lambda name: "slug")
332+
monkeypatch.setattr(db.settings, "gateway_tool_name_separator", "-")
333+
dummy_gateway = DummyGateway()
334+
dummy_connection = DummyConnection()
335+
dummy_mapper = DummyMapper()
336+
# Simulate get_history returning an object with has_changes = True
337+
class DummyHistory:
338+
def has_changes(self):
339+
return True
340+
monkeypatch.setattr(db, "get_history", lambda target, name: DummyHistory())
341+
db.update_tool_names_on_gateway_update(dummy_mapper, dummy_connection, dummy_gateway)
342+
assert hasattr(dummy_connection, "executed")
343+
344+
# --- SessionRecord and SessionMessageRecord ---
345+
def test_session_record_and_message_record():
346+
session = db.SessionRecord()
347+
session.session_id = "sid"
348+
session.data = "data"
349+
session.created_at = datetime.now(timezone.utc)
350+
session.last_accessed = datetime.now(timezone.utc)
351+
msg = db.SessionMessageRecord()
352+
msg.session_id = "sid"
353+
msg.message = "msg"
354+
msg.created_at = datetime.now(timezone.utc)
355+
msg.last_accessed = datetime.now(timezone.utc)
356+
session.messages = [msg]
357+
msg.session = session
358+
assert session.session_id == msg.session_id
359+
assert session.messages[0].message == "msg"
360+
assert msg.session.data == "data"

0 commit comments

Comments
 (0)