Skip to content

Commit d5a5019

Browse files
test case update (#224)
* test case update * Test case update Signed-off-by: Mohan Lakshmaiah <[email protected]> --------- Signed-off-by: Mohan Lakshmaiah <[email protected]> Co-authored-by: Mohan Lakshmaiah <[email protected]>
1 parent b4d8c54 commit d5a5019

File tree

4 files changed

+434
-8
lines changed

4 files changed

+434
-8
lines changed

tests/unit/mcpgateway/handlers/test_sampling.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
1515
No external MCP server is started; we test the isolated utility pieces that
1616
have no heavy dependencies.
17+
1718
"""
1819

1920
# Future
@@ -61,6 +62,70 @@ async def collector(msg):
6162
assert len(sent) == 1 and sent[0].message["id"] == 2
6263
assert sent[0].event_id == eid2
6364

65+
@pytest.mark.asyncio
66+
async def test_event_store_no_new_events():
67+
store = InMemoryEventStore(max_events_per_stream=10)
68+
stream_id = "stream1"
69+
eid = await store.store_event(stream_id, {"val": 42})
70+
sent = []
71+
async def collector(msg):
72+
sent.append(msg)
73+
74+
returned = await store.replay_events_after(eid, collector)
75+
assert returned == stream_id
76+
# No new events were stored, so nothing should be sent
77+
assert sent == []
78+
79+
@pytest.mark.asyncio
80+
async def test_event_store_multiple_replay():
81+
store = InMemoryEventStore(max_events_per_stream=10)
82+
stream_id = "stream1"
83+
# Store three events
84+
eids = []
85+
for i in range(3):
86+
eids.append(await store.store_event(stream_id, {"n": i}))
87+
sent = []
88+
async def collector(msg):
89+
sent.append(msg)
90+
91+
# Replay after the first event should get the 2nd and 3rd
92+
returned = await store.replay_events_after(eids[0], collector)
93+
assert returned == stream_id
94+
assert [msg.message["n"] for msg in sent] == [1, 2]
95+
96+
@pytest.mark.asyncio
97+
async def test_event_store_cross_streams():
98+
store = InMemoryEventStore(max_events_per_stream=10)
99+
s1, s2 = "s1", "s2"
100+
# Store events in two different streams
101+
eid1_s1 = await store.store_event(s1, {"val": 1})
102+
eid1_s2 = await store.store_event(s2, {"val": 2})
103+
eid2_s1 = await store.store_event(s1, {"val": 3})
104+
sent = []
105+
async def collector(msg):
106+
sent.append(msg)
107+
108+
# Replay on stream s1 after its first ID
109+
returned = await store.replay_events_after(eid1_s1, collector)
110+
assert returned == s1
111+
# Should only get the event from s1 (val=3), not the s2 event
112+
assert [msg.message["val"] for msg in sent] == [3]
113+
114+
@pytest.mark.asyncio
115+
async def test_event_store_eviction_of_oldest():
116+
store = InMemoryEventStore(max_events_per_stream=1)
117+
stream_id = "s"
118+
eid_old = await store.store_event(stream_id, {"x": "old"})
119+
# Storing a second event evicts the first (due to maxlen=1):contentReference[oaicite:8]{index=8}:contentReference[oaicite:9]{index=9}
120+
await store.store_event(stream_id, {"x": "new"})
121+
sent = []
122+
async def collector(msg):
123+
sent.append(msg)
124+
125+
result = await store.replay_events_after(eid_old, collector)
126+
# The first event ID has been evicted, so it should not be found
127+
assert result is None
128+
assert sent == []
64129

65130
@pytest.mark.asyncio
66131
async def test_event_store_eviction():
@@ -147,6 +212,42 @@ async def send(msg):
147212
assert sent and sent[0]["type"] == "http.response.start"
148213
assert sent[0]["status"] == tr.HTTP_401_UNAUTHORIZED
149214

215+
@pytest.mark.asyncio
216+
async def test_auth_valid_token(monkeypatch):
217+
# Simulate verify_credentials always succeeding
218+
async def fake_verify(token):
219+
assert token == "good-token"
220+
return {"ok": True}
221+
monkeypatch.setattr(tr, "verify_credentials", fake_verify)
222+
223+
messages = []
224+
async def send(msg):
225+
messages.append(msg)
226+
227+
scope = _make_scope("/servers/1/mcp",
228+
headers=[(b"authorization", b"Bearer good-token")])
229+
assert await streamable_http_auth(scope, None, send) is True
230+
assert messages == [] # No response sent on success
231+
232+
@pytest.mark.asyncio
233+
async def test_auth_invalid_token_raises(monkeypatch):
234+
# Simulate verify_credentials raising (invalid token scenario)
235+
async def fake_verify(token):
236+
raise ValueError("bad token")
237+
monkeypatch.setattr(tr, "verify_credentials", fake_verify)
238+
239+
sent = []
240+
async def send(msg):
241+
sent.append(msg)
242+
243+
scope = _make_scope("/servers/1/mcp",
244+
headers=[(b"authorization", b"Bearer bad-token")])
245+
result = await streamable_http_auth(scope, None, send)
246+
assert result is False
247+
# Expect an HTTP 401 response to be sent
248+
assert sent and sent[0]["type"] == "http.response.start"
249+
assert sent[0]["status"] == tr.HTTP_401_UNAUTHORIZED
250+
150251

151252
# ---------------------------------------------------------------------------
152253
# SamplingHandler tests
@@ -185,6 +286,13 @@ async def test_select_model_by_hint(handler):
185286

186287
assert handler._select_model(prefs) == "claude-3-sonnet" # pylint: disable=protected-access
187288

289+
@pytest.mark.asyncio
290+
async def test_select_model_no_suitable_model(handler):
291+
# Remove all supported models to force error
292+
handler._supported_models = {}
293+
prefs = _t.SimpleNamespace(hints=[], cost_priority=0, speed_priority=0, intelligence_priority=0)
294+
with pytest.raises(SamplingError):
295+
handler._select_model(prefs)
188296

189297
# ---------------------------------------------------------------------------
190298
# _validate_message
@@ -203,6 +311,54 @@ def test_validate_message(handler):
203311
assert handler._validate_message(valid_image) # pylint: disable=protected-access
204312
assert not handler._validate_message(invalid) # pylint: disable=protected-access
205313

314+
def test_validate_message_missing_image_fields(handler):
315+
# Missing 'data' field in image content
316+
invalid_img1 = {"role": "assistant", "content": {"type": "image", "mime_type": "image/png"}}
317+
# Missing 'mime_type' field
318+
invalid_img2 = {"role": "assistant", "content": {"type": "image", "data": "AAA"}}
319+
# Unknown content type
320+
invalid_img3 = {"role": "user", "content": {"type": "audio", "data": "xxx"}}
321+
322+
assert not handler._validate_message(invalid_img1)
323+
assert not handler._validate_message(invalid_img2)
324+
assert not handler._validate_message(invalid_img3)
325+
326+
@pytest.mark.asyncio
327+
async def test_add_context_returns_messages(handler):
328+
# Should just return the messages as-is (stub)
329+
msgs = [{"role": "user", "content": {"type": "text", "text": "hi"}}]
330+
result = await handler._add_context(None, msgs, "irrelevant")
331+
assert result == msgs
332+
333+
def test_mock_sample_no_user_message(handler):
334+
# No user message in the list
335+
msgs = [{"role": "assistant", "content": {"type": "text", "text": "hi"}}]
336+
result = handler._mock_sample(msgs)
337+
assert "I'm not sure" in result
338+
339+
def test_mock_sample_image_message(handler):
340+
# Last user message is image
341+
msgs = [
342+
{"role": "user", "content": {"type": "image", "data": "xxx", "mime_type": "image/png"}}
343+
]
344+
result = handler._mock_sample(msgs)
345+
assert "I see the image" in result
346+
347+
def test_validate_message_invalid_role(handler):
348+
msg = {"role": "system", "content": {"type": "text", "text": "hi"}}
349+
assert not handler._validate_message(msg)
350+
351+
def test_validate_message_missing_content(handler):
352+
msg = {"role": "user"}
353+
assert not handler._validate_message(msg)
354+
355+
def test_validate_message_exception_path(handler):
356+
# Simulate exception in validation
357+
class BadDict(dict):
358+
def get(self, k, d=None):
359+
raise Exception("fail")
360+
msg = {"role": "user", "content": BadDict()}
361+
assert not handler._validate_message(msg)
206362

207363
# ---------------------------------------------------------------------------
208364
# create_message success + error paths
@@ -228,6 +384,29 @@ async def test_create_message_success(monkeypatch, handler):
228384
assert result.role == sp.Role.ASSISTANT
229385
assert result.content.text.startswith("You said: Hello")
230386

387+
@pytest.mark.asyncio
388+
async def test_create_message_multiple_user_messages(monkeypatch, handler):
389+
# Return neutral preferences with no hints
390+
neutral_prefs = _t.SimpleNamespace(
391+
hints=[], cost_priority=0.5, speed_priority=0.3, intelligence_priority=0.2
392+
)
393+
monkeypatch.setattr(sp.ModelPreferences, "model_validate", lambda x: neutral_prefs)
394+
395+
# Conversation with an assistant message and then a user message
396+
request = {
397+
"messages": [
398+
{"role": "assistant", "content": {"type": "text", "text": "Hi"}},
399+
{"role": "user", "content": {"type": "text", "text": "Hello"}}
400+
],
401+
"maxTokens": 10,
402+
"modelPreferences": {}
403+
}
404+
405+
result = await handler.create_message(db=None, request=request)
406+
assert result.role == sp.Role.ASSISTANT
407+
# The response should reference the last user message "Hello"
408+
assert "You said: Hello" in result.content.text
409+
231410

232411
@pytest.mark.asyncio
233412
async def test_create_message_no_messages(monkeypatch, handler):
@@ -237,3 +416,45 @@ async def test_create_message_no_messages(monkeypatch, handler):
237416

238417
with pytest.raises(SamplingError):
239418
await handler.create_message(db=None, request=request)
419+
420+
@pytest.mark.asyncio
421+
async def test_create_message_raises_on_no_user_message(monkeypatch, handler):
422+
# Even if there are assistant messages, at least one user message is required
423+
monkeypatch.setattr(sp.ModelPreferences, "model_validate",
424+
lambda x: _t.SimpleNamespace(hints=[], cost_priority=0, speed_priority=0, intelligence_priority=0))
425+
request = {"messages": [], "maxTokens": 5, "modelPreferences": {}}
426+
with pytest.raises(SamplingError):
427+
await handler.create_message(db=None, request=request)
428+
429+
@pytest.mark.asyncio
430+
async def test_create_message_missing_max_tokens(monkeypatch, handler):
431+
monkeypatch.setattr(sp.ModelPreferences, "model_validate", lambda _x: _t.SimpleNamespace(hints=[], cost_priority=0, speed_priority=0, intelligence_priority=0))
432+
request = {"messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]}
433+
with pytest.raises(SamplingError):
434+
await handler.create_message(db=None, request=request)
435+
436+
@pytest.mark.asyncio
437+
async def test_create_message_invalid_message(monkeypatch, handler):
438+
monkeypatch.setattr(sp.ModelPreferences, "model_validate", lambda _x: _t.SimpleNamespace(hints=[], cost_priority=0, speed_priority=0, intelligence_priority=0))
439+
# Invalid message: missing text
440+
request = {
441+
"messages": [{"role": "user", "content": {"type": "text"}}],
442+
"maxTokens": 5,
443+
"modelPreferences": {},
444+
}
445+
with pytest.raises(SamplingError):
446+
await handler.create_message(db=None, request=request)
447+
448+
@pytest.mark.asyncio
449+
async def test_create_message_exception_propagation(monkeypatch, handler):
450+
# Patch _select_model to raise
451+
monkeypatch.setattr(handler, "_select_model", lambda prefs: (_ for _ in ()).throw(Exception("fail")))
452+
monkeypatch.setattr(sp.ModelPreferences, "model_validate", lambda _x: _t.SimpleNamespace(hints=[], cost_priority=0, speed_priority=0, intelligence_priority=0))
453+
request = {
454+
"messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}],
455+
"maxTokens": 5,
456+
"modelPreferences": {},
457+
}
458+
with pytest.raises(SamplingError) as exc:
459+
await handler.create_message(db=None, request=request)
460+
assert "fail" in str(exc.value)

tests/unit/mcpgateway/transports/test_sse_transport.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
Authors: Mihai Criveti
77
88
Tests for the MCP Gateway SSE transport implementation.
9+
910
"""
1011

1112
# Standard
1213
import asyncio
1314
import json
14-
from unittest.mock import Mock
15+
import types
16+
from unittest.mock import Mock, AsyncMock, patch
1517

1618
# First-Party
1719
from mcpgateway.transports.sse_transport import SSETransport
@@ -81,6 +83,51 @@ async def test_send_message_not_connected(self, sse_transport):
8183
# Should raise error
8284
with pytest.raises(RuntimeError, match="Transport not connected"):
8385
await sse_transport.send_message(message)
86+
@pytest.mark.asyncio
87+
async def test_receive_message_not_connected(self, sse_transport):
88+
"""receive_message should raise RuntimeError if not connected."""
89+
with pytest.raises(RuntimeError):
90+
async for _ in sse_transport.receive_message():
91+
pass
92+
93+
@pytest.mark.asyncio
94+
async def test_send_message_queue_exception(self, sse_transport):
95+
"""send_message should log and raise if queue.put fails."""
96+
await sse_transport.connect()
97+
with patch.object(sse_transport._message_queue, "put", side_effect=Exception("fail")), \
98+
patch("mcpgateway.transports.sse_transport.logger") as mock_logger:
99+
with pytest.raises(Exception, match="fail"):
100+
await sse_transport.send_message({"foo": "bar"})
101+
assert mock_logger.error.called
102+
103+
@pytest.mark.asyncio
104+
async def test_receive_message_cancelled(self, sse_transport):
105+
"""Test receive_message handles CancelledError and logs."""
106+
await sse_transport.connect()
107+
with patch("asyncio.sleep", side_effect=asyncio.CancelledError), \
108+
patch("mcpgateway.transports.sse_transport.logger") as mock_logger:
109+
gen = sse_transport.receive_message()
110+
await gen.__anext__() # initialize message
111+
with pytest.raises(asyncio.CancelledError):
112+
await gen.__anext__()
113+
# Check that logger.info was called with the cancel message
114+
assert any(
115+
"SSE receive loop cancelled" in str(call)
116+
for call in [args[0] for args, _ in mock_logger.info.call_args_list]
117+
)
118+
@pytest.mark.asyncio
119+
async def test_receive_message_finally_logs(self, sse_transport):
120+
"""Test receive_message logs in finally block."""
121+
await sse_transport.connect()
122+
with patch("asyncio.sleep", side_effect=Exception("fail")), \
123+
patch("mcpgateway.transports.sse_transport.logger") as mock_logger:
124+
gen = sse_transport.receive_message()
125+
await gen.__anext__() # initialize message
126+
with pytest.raises(Exception):
127+
await gen.__anext__()
128+
assert any("SSE receive loop ended" in str(call) for call in mock_logger.info.call_args_list)
129+
130+
84131

85132
@pytest.mark.asyncio
86133
async def test_create_sse_response(self, sse_transport, mock_request):
@@ -99,6 +146,37 @@ async def test_create_sse_response(self, sse_transport, mock_request):
99146
assert response.headers["Cache-Control"] == "no-cache"
100147
assert response.headers["Content-Type"] == "text/event-stream"
101148
assert response.headers["X-MCP-SSE"] == "true"
149+
150+
@pytest.mark.asyncio
151+
async def test_create_sse_response_event_generator_error(self, sse_transport, mock_request):
152+
"""Test event_generator handles generic Exception and CancelledError."""
153+
await sse_transport.connect()
154+
# Patch _message_queue.get to raise Exception, then CancelledError
155+
with patch.object(sse_transport._message_queue, "get", side_effect=[Exception("fail"), asyncio.CancelledError()]), \
156+
patch("mcpgateway.transports.sse_transport.logger") as mock_logger:
157+
response = await sse_transport.create_sse_response(mock_request)
158+
gen = response.body_iterator
159+
await gen.__anext__() # endpoint
160+
await gen.__anext__() # keepalive
161+
# Should yield error event
162+
event = await gen.__anext__()
163+
assert event["event"] == "error"
164+
assert "fail" in event["data"]
165+
# Should handle CancelledError gracefully and stop
166+
with pytest.raises(StopAsyncIteration):
167+
await gen.__anext__()
168+
assert mock_logger.error.called or mock_logger.info.called
169+
170+
def test_session_id_property(self, sse_transport):
171+
"""Test session_id property returns the correct value."""
172+
assert sse_transport.session_id == sse_transport._session_id
173+
174+
@pytest.mark.asyncio
175+
async def test_client_disconnected(self, sse_transport, mock_request):
176+
"""Test _client_disconnected returns correct state."""
177+
assert await sse_transport._client_disconnected(mock_request) is False
178+
sse_transport._client_gone.set()
179+
assert await sse_transport._client_disconnected(mock_request) is True
102180

103181
@pytest.mark.asyncio
104182
async def test_receive_message(self, sse_transport):

0 commit comments

Comments
 (0)