Skip to content

Commit fbb3445

Browse files
committed
Improve test coverage
Signed-off-by: Mihai Criveti <[email protected]>
1 parent cb467ff commit fbb3445

File tree

8 files changed

+771
-6
lines changed

8 files changed

+771
-6
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# -*- coding: utf-8 -*-
2+
"""Memory-backend unit tests for `session_registry.py`.
3+
4+
Copyright 2025
5+
SPDX-License-Identifier: Apache-2.0
6+
Author: Mihai Criveti
7+
8+
These tests cover the essential public behaviours of SessionRegistry when
9+
configured with backend="memory":
10+
11+
* add_session / get_session / get_session_sync
12+
* remove_session (disconnects transport & clears cache)
13+
* broadcast + respond (with generate_response monkey-patched)
14+
15+
No Redis or SQLAlchemy fixtures are required, making the suite fast and
16+
portable.
17+
"""
18+
19+
import asyncio
20+
21+
import pytest
22+
23+
# Import SessionRegistry – works whether the file lives inside the package or beside it
24+
try:
25+
from mcpgateway.cache.session_registry import SessionRegistry
26+
except (ModuleNotFoundError, ImportError): # pragma: no cover
27+
from session_registry import SessionRegistry # type: ignore
28+
29+
30+
class FakeSSETransport:
31+
"""Minimal stub implementing only the methods SessionRegistry uses."""
32+
33+
def __init__(self, session_id: str):
34+
self.session_id = session_id
35+
self._connected = True
36+
self.sent_messages = []
37+
38+
async def disconnect(self):
39+
self._connected = False
40+
41+
async def is_connected(self):
42+
return self._connected
43+
44+
async def send_message(self, msg):
45+
self.sent_messages.append(msg)
46+
47+
48+
# ---------------------------------------------------------------------------
49+
# Pytest fixtures
50+
# ---------------------------------------------------------------------------
51+
52+
53+
@pytest.fixture(name="event_loop")
54+
def _event_loop_fixture():
55+
loop = asyncio.new_event_loop()
56+
yield loop
57+
loop.close()
58+
59+
60+
@pytest.fixture()
61+
async def registry():
62+
reg = SessionRegistry(backend="memory")
63+
await reg.initialize()
64+
yield reg
65+
await reg.shutdown()
66+
67+
68+
# ---------------------------------------------------------------------------
69+
# Tests
70+
# ---------------------------------------------------------------------------
71+
72+
73+
@pytest.mark.asyncio
74+
async def test_add_and_get_session(registry):
75+
tr = FakeSSETransport("abc")
76+
await registry.add_session("abc", tr)
77+
78+
assert await registry.get_session("abc") is tr
79+
assert registry.get_session_sync("abc") is tr
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_remove_session(registry):
84+
tr = FakeSSETransport("dead")
85+
await registry.add_session("dead", tr)
86+
87+
await registry.remove_session("dead")
88+
89+
assert not await tr.is_connected()
90+
assert registry.get_session_sync("dead") is None
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_broadcast_and_respond(monkeypatch, registry):
95+
"""Ensure broadcast stores the message and respond delivers it via generate_response."""
96+
tr = FakeSSETransport("xyz")
97+
await registry.add_session("xyz", tr)
98+
99+
captured = {}
100+
101+
async def fake_generate_response(*, message, transport, **_):
102+
captured["transport"] = transport
103+
captured["message"] = message
104+
105+
monkeypatch.setattr(registry, "generate_response", fake_generate_response)
106+
107+
ping_msg = {"method": "ping", "id": 1, "params": {}}
108+
await registry.broadcast("xyz", ping_msg)
109+
110+
# respond should call our fake_generate_response exactly once
111+
await registry.respond(
112+
server_id=None,
113+
user={},
114+
session_id="xyz",
115+
base_url="http://localhost",
116+
)
117+
118+
assert captured["transport"] is tr
119+
assert captured["message"] == ping_msg
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# -*- coding: utf-8 -*-
2+
"""Unit tests for Federation Discovery Service.
3+
4+
Copyright 2025
5+
SPDX-License-Identifier: Apache-2.0
6+
Authors: Mihai Criveti
7+
8+
Comprehensive unit tests for the discovery service module.
9+
"""
10+
11+
# tests/test_discovery.py
12+
import asyncio
13+
from datetime import datetime
14+
15+
import pytest
16+
17+
from mcpgateway.federation.discovery import (
18+
PROTOCOL_VERSION,
19+
DiscoveryService,
20+
)
21+
22+
# ---------------------------------------------------------------------------
23+
# Helpers / fixtures
24+
# ---------------------------------------------------------------------------
25+
26+
27+
@pytest.fixture
28+
async def discovery():
29+
"""
30+
Provide a DiscoveryService instance whose network-touching method
31+
`_get_gateway_info` is monkey-patched so no real HTTP requests are made.
32+
"""
33+
ds = DiscoveryService()
34+
35+
async def _fake_gateway_info(url: str): # noqa: D401, ANN001
36+
# Return an *empty* capabilities object – structure is unimportant here.
37+
from mcpgateway.types import ServerCapabilities
38+
39+
return ServerCapabilities()
40+
41+
# Patch the network call
42+
ds._get_gateway_info = _fake_gateway_info # type: ignore[attr-defined]
43+
44+
yield ds
45+
await ds.stop() # ensure graceful cleanup
46+
47+
48+
# ---------------------------------------------------------------------------
49+
# Tests
50+
# ---------------------------------------------------------------------------
51+
52+
53+
@pytest.mark.anyio
54+
async def test_add_peer_success(discovery):
55+
added = await discovery.add_peer("http://example.com", source="test", name="Example")
56+
assert added is True, "first call should add the peer"
57+
58+
peers = discovery.get_discovered_peers()
59+
assert len(peers) == 1
60+
peer = peers[0]
61+
62+
assert peer.url == "http://example.com"
63+
assert peer.name == "Example"
64+
assert peer.protocol_version == PROTOCOL_VERSION
65+
assert peer.source == "test"
66+
67+
68+
@pytest.mark.anyio
69+
async def test_add_duplicate_peer_is_ignored(discovery):
70+
await discovery.add_peer("http://dup.com", source="test")
71+
second_add = await discovery.add_peer("http://dup.com", source="test-again")
72+
assert second_add is False, "duplicate add should be a no-op"
73+
74+
peers = discovery.get_discovered_peers()
75+
assert len(peers) == 1, "still only one peer stored"
76+
77+
78+
@pytest.mark.anyio
79+
async def test_add_peer_invalid_url_returns_false(discovery):
80+
added = await discovery.add_peer("not-a-valid-url", source="test")
81+
assert added is False
82+
assert discovery.get_discovered_peers() == []
83+
84+
85+
@pytest.mark.anyio
86+
async def test_refresh_peer_updates_last_seen(discovery):
87+
await discovery.add_peer("http://refresh.me", source="test")
88+
peer = discovery.get_discovered_peers()[0]
89+
first_seen = peer.last_seen
90+
91+
# Wait a moment to ensure a measurable delta
92+
await asyncio.sleep(0.01)
93+
94+
refreshed = await discovery.refresh_peer("http://refresh.me")
95+
assert refreshed is True
96+
assert peer.last_seen > first_seen

tests/unit/mcpgateway/test_forward.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# -*- coding: utf-8 -*-
2+
"""Unit tests for Federation Forwarding Service.
3+
4+
Copyright 2025
5+
SPDX-License-Identifier: Apache-2.0
6+
Authors: Mihai Criveti
7+
8+
Comprehensive unit tests for the forwarding service module.
9+
"""
10+
11+
import asyncio
12+
from datetime import datetime
13+
14+
import pytest
15+
16+
from mcpgateway.federation.forward import ForwardingError, ForwardingService, ToolResult
17+
from mcpgateway.types import TextContent
18+
19+
# ---------------------------------------------------------------------------
20+
# Tiny dummy ORM objects + fake Session
21+
# ---------------------------------------------------------------------------
22+
23+
24+
class DummyGateway:
25+
def __init__(self, id_, name, url, is_active=True):
26+
self.id = id_
27+
self.name = name
28+
self.url = url
29+
self.is_active = is_active
30+
self.last_seen: datetime | None = None
31+
32+
33+
class DummyTool:
34+
def __init__(self, id_, name, gateway_id=None, is_active=True):
35+
self.id = id_
36+
self.name = name
37+
self.gateway_id = gateway_id
38+
self.is_active = is_active
39+
40+
41+
class _FakeResult:
42+
def __init__(self, scalar=None, scalar_list=None):
43+
self._scalar = scalar
44+
self._scalar_list = scalar_list or []
45+
46+
def scalar_one_or_none(self):
47+
return self._scalar
48+
49+
class _Proxy:
50+
def __init__(self, items):
51+
self._items = items
52+
53+
def all(self):
54+
return self._items
55+
56+
def scalars(self):
57+
return self._Proxy(self._scalar_list)
58+
59+
60+
class FakeSession:
61+
def __init__(self, gateways=None, tools=None):
62+
self._gateways = gateways or []
63+
self._tools = tools or []
64+
65+
def get(self, _model, pk):
66+
for gw in self._gateways:
67+
if gw.id == pk:
68+
return gw
69+
return None
70+
71+
def execute(self, query): # pragma: no cover
72+
txt = str(query).lower()
73+
if "tool" in txt:
74+
return _FakeResult(scalar=self._tools[0] if self._tools else None)
75+
return _FakeResult(scalar_list=self._gateways)
76+
77+
78+
# ---------------------------------------------------------------------------
79+
# ForwardingService fixture with network stubbed out
80+
# ---------------------------------------------------------------------------
81+
82+
83+
@pytest.fixture
84+
async def fwd_service(monkeypatch):
85+
svc = ForwardingService()
86+
87+
class _FakeResp:
88+
def __init__(self, payload):
89+
self._payload = payload
90+
91+
def raise_for_status(self):
92+
pass
93+
94+
def json(self):
95+
return self._payload
96+
97+
async def fake_post(url, json=None, headers=None): # noqa: D401
98+
return _FakeResp({"result": {"method": json["method"]}})
99+
100+
monkeypatch.setattr(svc._http_client, "post", fake_post)
101+
102+
from mcpgateway.config import settings
103+
104+
monkeypatch.setattr(settings, "max_tool_retries", 1, raising=False)
105+
yield svc
106+
await svc.stop()
107+
108+
109+
# ---------------------------------------------------------------------------
110+
# Tests
111+
# ---------------------------------------------------------------------------
112+
113+
114+
@pytest.mark.anyio
115+
async def test_forward_to_gateway_success(fwd_service):
116+
gw = DummyGateway(1, "Alpha", "http://alpha")
117+
db = FakeSession(gateways=[gw])
118+
119+
result = await fwd_service._forward_to_gateway(db, 1, "ping", {"x": 1})
120+
assert result == {"method": "ping"}
121+
assert isinstance(gw.last_seen, datetime)
122+
123+
124+
@pytest.mark.anyio
125+
async def test_forward_tool_request_parses_result(monkeypatch, fwd_service):
126+
# Fake gateway call to produce valid ContentType payload
127+
async def fake_forward(db, gid, method, params): # noqa: D401
128+
assert method == "tools/invoke"
129+
payload = {
130+
"content": [TextContent(type="text", text="OK").model_dump()],
131+
"is_error": False,
132+
}
133+
return payload
134+
135+
monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward) # type: ignore[arg-type]
136+
137+
tool = DummyTool(1, "echo", gateway_id=42)
138+
db = FakeSession(gateways=[DummyGateway(42, "EchoGW", "http://echo")], tools=[tool])
139+
140+
result: ToolResult = await fwd_service.forward_tool_request(db, "echo", {"msg": "hi"})
141+
assert result.is_error is False
142+
assert isinstance(result.content[0], TextContent) and result.content[0].text == "OK"
143+
144+
145+
@pytest.mark.anyio
146+
async def test_rate_limit(monkeypatch):
147+
from mcpgateway.config import settings
148+
149+
monkeypatch.setattr(settings, "tool_rate_limit", 2, raising=False)
150+
151+
svc = ForwardingService()
152+
url = "http://beta"
153+
assert svc._check_rate_limit(url)
154+
assert svc._check_rate_limit(url)
155+
assert svc._check_rate_limit(url) is False # third call exceeds limit
156+
157+
158+
@pytest.mark.anyio
159+
async def test_forward_to_all_partial_success(monkeypatch, fwd_service):
160+
gw_ok = DummyGateway(1, "GoodGW", "http://good")
161+
gw_bad = DummyGateway(2, "BadGW", "http://bad")
162+
163+
async def fake_forward(db, gid, method, params): # noqa: D401
164+
if gid == 1:
165+
return "ok!"
166+
raise ForwardingError("boom")
167+
168+
monkeypatch.setattr(fwd_service, "_forward_to_gateway", fake_forward) # type: ignore[arg-type]
169+
170+
db = FakeSession(gateways=[gw_ok, gw_bad])
171+
results = await fwd_service._forward_to_all(db, "stats/get")
172+
assert results == ["ok!"]

0 commit comments

Comments
 (0)