Skip to content

Commit c8f9a0c

Browse files
authored
Test case update (#346)
Signed-off-by: Mohan Lakshmaiah <[email protected]>
1 parent 1ee0ca4 commit c8f9a0c

File tree

1 file changed

+178
-2
lines changed

1 file changed

+178
-2
lines changed

tests/unit/mcpgateway/test_translate.py

Lines changed: 178 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import sys
4242
import types
4343
from typing import Sequence
44-
from unittest.mock import AsyncMock, Mock
44+
from unittest.mock import AsyncMock, Mock, MagicMock
4545

4646
# Third-Party
4747
from fastapi.testclient import TestClient
@@ -69,6 +69,23 @@ def translate():
6969
sys.modules.pop("mcpgateway.translate", None)
7070
return importlib.import_module("mcpgateway.translate")
7171

72+
def test_translate_importerror(monkeypatch):
73+
# Remove httpx from sys.modules if present
74+
sys.modules.pop("httpx", None)
75+
# Simulate ImportError when importing httpx
76+
import builtins
77+
real_import = builtins.__import__
78+
79+
def fake_import(name, *args, **kwargs):
80+
if name == "httpx":
81+
raise ImportError("No module named 'httpx'")
82+
return real_import(name, *args, **kwargs)
83+
84+
monkeypatch.setattr(builtins, "__import__", fake_import)
85+
# Reload the module to trigger the import block
86+
import mcpgateway.translate as translate
87+
importlib.reload(translate)
88+
assert translate.httpx is None
7289

7390
# ---------------------------------------------------------------------------#
7491
# Dummy subprocess plumbing #
@@ -137,12 +154,31 @@ def put_nowait(self, *_): # type: ignore[override]
137154
await ps.publish("x")
138155
assert bad not in ps._subscribers
139156

157+
@pytest.mark.asyncio
158+
async def test_pubsub_double_unsubscribe_and_publish_no_subs(translate):
159+
ps = translate._PubSub()
160+
q = ps.subscribe()
161+
ps.unsubscribe(q)
162+
# Unsubscribing again should not raise
163+
ps.unsubscribe(q)
164+
# Publishing with no subscribers should not raise
165+
await ps.publish("no one listens")
140166

141167
# ---------------------------------------------------------------------------#
142168
# Tests: StdIOEndpoint #
143169
# ---------------------------------------------------------------------------#
144170

145171

172+
@pytest.mark.asyncio
173+
async def test_stdio_endpoint_stop_when_proc_none(translate):
174+
"""Test StdIOEndpoint.stop() returns immediately if _proc is None."""
175+
ps = translate._PubSub()
176+
ep = translate.StdIOEndpoint("echo test", ps)
177+
# Ensure _proc is None (should be by default)
178+
assert ep._proc is None
179+
# Should not raise or do anything
180+
await ep.stop()
181+
146182
@pytest.mark.asyncio
147183
async def test_stdio_endpoint_flow(monkeypatch, translate):
148184
ps = translate._PubSub()
@@ -209,6 +245,23 @@ async def _fake_exec(*_a, **_kw):
209245
await ep.stop() # Should handle timeout gracefully
210246
assert fake.terminated
211247

248+
@pytest.mark.asyncio
249+
async def test_stdio_endpoint_stop_cancels_pump(monkeypatch, translate):
250+
ps = translate._PubSub()
251+
fake = _FakeProc(['{"jsonrpc":"2.0"}\n'])
252+
253+
async def _fake_exec(*_a, **_kw):
254+
return fake
255+
256+
monkeypatch.setattr(translate.asyncio, "create_subprocess_exec", _fake_exec)
257+
258+
ep = translate.StdIOEndpoint("echo hi", ps)
259+
await ep.start()
260+
# Simulate pump task still running
261+
assert ep._pump_task is not None
262+
# Stop should cancel the pump task
263+
await ep.stop()
264+
assert fake.terminated
212265

213266
# ---------------------------------------------------------------------------#
214267
# Tests: FastAPI facade (/sse /message /healthz) #
@@ -349,6 +402,46 @@ def test_fastapi_custom_paths(translate):
349402
assert "/healthz" in route_paths # Default health endpoint should still exist
350403

351404

405+
def test_build_fastapi_with_cors_and_keepalive(translate):
406+
ps = translate._PubSub()
407+
stdio = Mock()
408+
app = translate._build_fastapi(ps, stdio, keep_alive=5, cors_origins=["*"])
409+
assert app is not None
410+
# Check CORS middleware is present
411+
assert any("CORSMiddleware" in str(m) for m in app.user_middleware)
412+
413+
414+
@pytest.mark.asyncio
415+
async def test_sse_event_gen_unsubscribes_on_disconnect(monkeypatch, translate):
416+
ps = translate._PubSub()
417+
stdio = Mock()
418+
app = translate._build_fastapi(ps, stdio)
419+
420+
# Patch request to simulate disconnect after first yield
421+
class DummyRequest:
422+
def __init__(self):
423+
self.base_url = "http://test/"
424+
self._disconnected = False
425+
async def is_disconnected(self):
426+
if not self._disconnected:
427+
self._disconnected = True
428+
return False
429+
return True
430+
431+
# Get the /sse route handler
432+
for route in app.routes:
433+
if getattr(route, "path", None) == "/sse":
434+
handler = route.endpoint
435+
break
436+
437+
# Call the handler and exhaust the generator
438+
resp = await handler(DummyRequest())
439+
# The generator should unsubscribe after disconnect (no error)
440+
assert resp is not None
441+
442+
443+
444+
352445
# ---------------------------------------------------------------------------#
353446
# Tests: _parse_args #
354447
# ---------------------------------------------------------------------------#
@@ -386,6 +479,13 @@ def test_parse_args_log_level(translate):
386479
ns = translate._parse_args(["--stdio", "echo hi", "--logLevel", "debug"])
387480
assert ns.logLevel == "debug"
388481

482+
def test_parse_args_missing_required(translate):
483+
import sys
484+
argv = []
485+
# Should exit with SystemExit due to missing required argument
486+
with pytest.raises(SystemExit):
487+
translate._parse_args(argv)
488+
389489

390490
# ---------------------------------------------------------------------------#
391491
# Tests: _run_stdio_to_sse orchestration #
@@ -441,6 +541,7 @@ async def shutdown(self):
441541
await asyncio.wait_for(_test_logic(), timeout=3.0)
442542

443543

544+
444545
@pytest.mark.asyncio
445546
async def test_run_stdio_to_sse_with_cors(monkeypatch, translate):
446547
"""Test _run_stdio_to_sse with CORS configuration."""
@@ -737,12 +838,70 @@ def stream(self, *_a, **_kw):
737838
# Add timeout to prevent hanging
738839
await asyncio.wait_for(_test_logic(), timeout=5.0)
739840

841+
@pytest.mark.asyncio
842+
async def test_run_sse_to_stdio_importerror(monkeypatch, translate):
843+
monkeypatch.setattr(translate, "httpx", None)
844+
with pytest.raises(ImportError):
845+
await translate._run_sse_to_stdio("http://dummy/sse", None)
740846

847+
@pytest.mark.asyncio
848+
async def test_pump_sse_to_stdio_full(monkeypatch, translate):
849+
# Prepare fake process with mock stdin
850+
written = []
851+
class DummyStdin:
852+
def write(self, data):
853+
written.append(data)
854+
async def drain(self):
855+
written.append("drained")
856+
857+
class DummyProcess:
858+
stdin = DummyStdin()
859+
860+
# Prepare fake response with aiter_lines
861+
lines = [
862+
"event: message",
863+
"data: ", # Should be skipped
864+
"data: {}", # Should be skipped
865+
"data: {\"jsonrpc\":\"2.0\",\"result\":\"ok\"}", # Should be written
866+
"data: another", # Should be written
867+
"notdata: ignored", # Should be ignored
868+
]
869+
class DummyResponse:
870+
async def __aenter__(self): return self
871+
async def __aexit__(self, *a): pass
872+
async def aiter_lines(self):
873+
for line in lines:
874+
yield line
875+
876+
class DummyClient:
877+
async def __aenter__(self): return self
878+
async def __aexit__(self, *a): pass
879+
def stream(self, *a, **k): return DummyResponse()
880+
881+
# Patch httpx.AsyncClient to return DummyClient
882+
monkeypatch.setattr(translate, "httpx", MagicMock())
883+
translate.httpx.AsyncClient = MagicMock(return_value=DummyClient())
884+
885+
# Patch asyncio.create_subprocess_shell to return DummyProcess
886+
monkeypatch.setattr(translate.asyncio, "create_subprocess_shell", AsyncMock(return_value=DummyProcess()))
887+
888+
# Patch process.stdout so read_stdout() exits immediately
889+
class DummyStdout:
890+
async def readline(self): return b""
891+
DummyProcess.stdout = DummyStdout()
892+
893+
# Actually call _run_sse_to_stdio, which will define and call pump_sse_to_stdio
894+
await translate._run_sse_to_stdio("http://dummy/sse", None)
895+
896+
# Check that only the correct data was written and drained
897+
# Should skip empty and {} data, write the others
898+
assert b'{"jsonrpc":"2.0","result":"ok"}\n' in written
899+
assert b'another\n' in written
900+
assert "drained" in written
741901
# ---------------------------------------------------------------------------#
742902
# Tests: CLI entry-point (`python -m mcpgateway.translate`) #
743903
# ---------------------------------------------------------------------------#
744904

745-
746905
def test_module_entrypoint(monkeypatch, translate):
747906
"""Test that the module can be executed as __main__."""
748907
executed: list[str] = []
@@ -842,6 +1001,17 @@ def _raise_not_implemented(*args):
8421001
captured = capsys.readouterr()
8431002
assert "Test error message" in captured.err
8441003

1004+
def test_main_unknown_args(monkeypatch, translate):
1005+
monkeypatch.setattr(
1006+
translate,
1007+
"_parse_args",
1008+
lambda argv: type("Args", (), {
1009+
"stdio": None, "sse": None, "streamableHttp": None,
1010+
"logLevel": "info", "cors": None, "oauth2Bearer": None, "port": 8000
1011+
})()
1012+
)
1013+
# Just call main and assert it returns None (does not raise)
1014+
assert translate.main(["--unknown"]) is None
8451015

8461016
# ---------------------------------------------------------------------------#
8471017
# Tests: Edge cases and error paths #
@@ -937,3 +1107,9 @@ async def _fake_exec(*_a, **_kw):
9371107

9381108
# Add timeout to prevent hanging
9391109
await asyncio.wait_for(_test_logic(), timeout=3.0)
1110+
1111+
@pytest.mark.asyncio
1112+
async def test_stdio_endpoint_send_not_started(translate):
1113+
ep = translate.StdIOEndpoint("cmd", translate._PubSub())
1114+
with pytest.raises(RuntimeError):
1115+
await ep.send("test")

0 commit comments

Comments
 (0)