Skip to content

Commit 03f306c

Browse files
committed
test_stdio_transport.py fix
Signed-off-by: Mihai Criveti <[email protected]>
1 parent 6ef039a commit 03f306c

File tree

1 file changed

+126
-252
lines changed

1 file changed

+126
-252
lines changed
Lines changed: 126 additions & 252 deletions
Original file line numberDiff line numberDiff line change
@@ -1,270 +1,144 @@
11
# -*- coding: utf-8 -*-
2-
"""Tests for the stdio transport implementation.
2+
"""
3+
Unit-tests for `mcpgateway.transports.stdio_transport.StdioTransport`.
4+
5+
The real transport interacts with the running event-loop and the
6+
process's stdin / stdout file-descriptors. Those OS objects are tricky
7+
to mock portably, so these tests **inject in-memory fakes** in place of
8+
`StreamReader` and `StreamWriter`. That lets us assert the transport's
9+
logic (JSON encoding/decoding, connection state, error handling) without
10+
ever touching real pipes.
311
412
Copyright 2025
513
SPDX-License-Identifier: Apache-2.0
614
Authors: Mihai Criveti
7-
8-
This module tests the stdio transport for MCP, ensuring it properly handles
9-
communication over standard input/output streams.
1015
"""
1116

17+
from __future__ import annotations
18+
1219
import asyncio
13-
from unittest.mock import AsyncMock, MagicMock, patch
20+
import json
21+
from typing import List
1422

1523
import pytest
1624

1725
from mcpgateway.transports.stdio_transport import StdioTransport
1826

1927

28+
# ---------------------------------------------------------------------------
29+
# Simple in-memory stand-ins for asyncio streams
30+
# ---------------------------------------------------------------------------
31+
32+
33+
class _DummyWriter:
34+
"""Enough of the `StreamWriter` interface for our tests."""
35+
36+
def __init__(self):
37+
self.buffer: List[bytes] = []
38+
self.drain_called = False
39+
self.closed = False
40+
41+
def write(self, data: bytes) -> None: # noqa: D401
42+
self.buffer.append(data)
43+
44+
async def drain(self) -> None: # noqa: D401
45+
self.drain_called = True
46+
47+
def close(self) -> None: # noqa: D401
48+
self.closed = True
49+
50+
async def wait_closed(self) -> None: # noqa: D401
51+
return
52+
53+
54+
class _DummyReader:
55+
"""Yield bytes from an internal list each time `.readline()` is awaited."""
56+
57+
def __init__(self, lines: list[str]):
58+
self._lines = [l.encode() + b"\n" for l in lines]
59+
60+
async def readline(self) -> bytes: # noqa: D401
61+
await asyncio.sleep(0) # let the event-loop breathe
62+
return self._lines.pop(0) if self._lines else b""
63+
64+
65+
# ---------------------------------------------------------------------------
66+
# Fixtures
67+
# ---------------------------------------------------------------------------
68+
69+
2070
@pytest.fixture
21-
def stdio_transport():
22-
"""Create a StdioTransport instance for testing."""
71+
def transport():
72+
"""Provide an *unconnected* StdioTransport for each test."""
2373
return StdioTransport()
2474

2575

26-
class TestStdioTransport:
27-
"""Test suite for the StdioTransport class."""
28-
29-
@patch("asyncio.get_running_loop")
30-
async def test_connect(self, mock_get_loop, stdio_transport):
31-
"""Test establishing a connection."""
32-
# Set up mocks
33-
mock_loop = MagicMock()
34-
mock_get_loop.return_value = mock_loop
35-
36-
MagicMock()
37-
mock_reader_protocol = MagicMock()
38-
39-
mock_transport = MagicMock()
40-
mock_protocol = MagicMock()
41-
42-
# Mock the connect_read_pipe and connect_write_pipe methods
43-
mock_loop.connect_read_pipe = AsyncMock(return_value=(mock_transport, mock_reader_protocol))
44-
mock_loop.connect_write_pipe = AsyncMock(return_value=(mock_transport, mock_protocol))
45-
46-
# Call the method under test
47-
await stdio_transport.connect()
48-
49-
# Verify the expected calls
50-
mock_get_loop.assert_called_once()
51-
mock_loop.connect_read_pipe.assert_called_once()
52-
mock_loop.connect_write_pipe.assert_called_once()
53-
54-
# Verify the connection state
55-
assert stdio_transport._connected is True
56-
57-
@patch("asyncio.StreamWriter")
58-
async def test_disconnect(self, mock_writer, stdio_transport):
59-
"""Test closing a connection."""
60-
# Set up mock
61-
stdio_transport._stdout_writer = mock_writer
62-
stdio_transport._connected = True
63-
64-
# Call the method under test
65-
await stdio_transport.disconnect()
66-
67-
# Verify the expected calls
68-
mock_writer.close.assert_called_once()
69-
mock_writer.wait_closed.assert_called_once()
70-
71-
# Verify the connection state
72-
assert stdio_transport._connected is False
73-
74-
async def test_disconnect_not_connected(self, stdio_transport):
75-
"""Test disconnecting when not connected."""
76-
# Ensure not connected
77-
stdio_transport._stdout_writer = None
78-
stdio_transport._connected = False
79-
80-
# Call the method under test
81-
await stdio_transport.disconnect()
82-
83-
# Verify the connection state
84-
assert stdio_transport._connected is False
85-
86-
async def test_send_message_not_connected(self, stdio_transport):
87-
"""Test sending a message when not connected."""
88-
# Ensure not connected
89-
stdio_transport._stdout_writer = None
90-
stdio_transport._connected = False
91-
92-
# Verify the expected exception
93-
with pytest.raises(RuntimeError, match="Transport not connected"):
94-
await stdio_transport.send_message({"type": "test"})
95-
96-
@patch("asyncio.StreamWriter")
97-
async def test_send_message(self, mock_writer, stdio_transport):
98-
"""Test sending a message."""
99-
# Set up mock
100-
stdio_transport._stdout_writer = mock_writer
101-
stdio_transport._connected = True
102-
103-
# Call the method under test
104-
await stdio_transport.send_message({"type": "test", "data": "message"})
105-
106-
# Verify the expected calls
107-
mock_writer.write.assert_called_once()
108-
mock_writer.drain.assert_called_once()
109-
110-
# Verify the encoded message
111-
call_args = mock_writer.write.call_args[0][0]
112-
assert b'{"type": "test", "data": "message"}\n' == call_args
113-
114-
@patch("asyncio.StreamWriter")
115-
async def test_send_message_exception(self, mock_writer, stdio_transport):
116-
"""Test sending a message when an error occurs."""
117-
# Set up mock with an exception
118-
stdio_transport._stdout_writer = mock_writer
119-
stdio_transport._connected = True
120-
mock_writer.write.side_effect = Exception("Write error")
121-
122-
# Verify the expected exception
123-
with pytest.raises(Exception, match="Write error"):
124-
await stdio_transport.send_message({"type": "test"})
125-
126-
async def test_receive_message_not_connected(self, stdio_transport):
127-
"""Test receiving messages when not connected."""
128-
# Ensure not connected
129-
stdio_transport._stdin_reader = None
130-
stdio_transport._connected = False
131-
132-
# Verify the expected exception
133-
with pytest.raises(RuntimeError, match="Transport not connected"):
134-
async for _ in stdio_transport.receive_message():
135-
pass
136-
137-
@patch("asyncio.StreamReader")
138-
async def test_receive_message(self, mock_reader, stdio_transport):
139-
"""Test receiving messages."""
140-
# Set up mock with two messages followed by EOF
141-
stdio_transport._stdin_reader = mock_reader
142-
stdio_transport._connected = True
143-
144-
message1 = b'{"type": "message1", "data": "test1"}\n'
145-
message2 = b'{"type": "message2", "data": "test2"}\n'
146-
147-
# Configure the mock to return messages and then an empty response (EOF)
148-
mock_reader.readline = AsyncMock(side_effect=[message1, message2, b""])
149-
150-
# Collect received messages
151-
received = []
152-
async for message in stdio_transport.receive_message():
153-
received.append(message)
154-
155-
# Verify the expected calls and received messages
156-
assert mock_reader.readline.call_count == 3
157-
assert len(received) == 2
158-
assert received[0] == {"type": "message1", "data": "test1"}
159-
assert received[1] == {"type": "message2", "data": "test2"}
160-
161-
@patch("asyncio.StreamReader")
162-
async def test_receive_message_exception(self, mock_reader, stdio_transport):
163-
"""Test receiving messages when a non-fatal error occurs."""
164-
# Set up mock with a valid message, then an invalid one, then EOF
165-
stdio_transport._stdin_reader = mock_reader
166-
stdio_transport._connected = True
167-
168-
message1 = b'{"type": "message1", "data": "test1"}\n'
169-
invalid_message = b"not valid json\n"
170-
171-
# Configure the mock to return a valid message, then an invalid one, then EOF
172-
mock_reader.readline = AsyncMock(side_effect=[message1, invalid_message, b""])
173-
174-
# Collect received messages (should only get the valid one)
175-
received = []
176-
async for message in stdio_transport.receive_message():
177-
received.append(message)
178-
179-
# Verify the expected calls and received messages
180-
assert mock_reader.readline.call_count == 3
181-
assert len(received) == 1
182-
assert received[0] == {"type": "message1", "data": "test1"}
183-
184-
@patch("asyncio.StreamReader")
185-
async def test_receive_message_cancellation(self, mock_reader, stdio_transport):
186-
"""Test receiving messages with cancellation."""
187-
# Set up mock with a message
188-
stdio_transport._stdin_reader = mock_reader
189-
stdio_transport._connected = True
190-
191-
message = b'{"type": "message", "data": "test"}\n'
192-
193-
# Configure the mock to return a message and then raise a cancellation
194-
mock_reader.readline = AsyncMock(side_effect=[message, asyncio.CancelledError()])
195-
196-
# Collect received messages until cancellation
197-
received = []
198-
try:
199-
async for message in stdio_transport.receive_message():
200-
received.append(message)
201-
except asyncio.CancelledError:
202-
pass # Expected
203-
204-
# Verify the expected calls and received messages
205-
assert mock_reader.readline.call_count == 2
206-
assert len(received) == 1
207-
assert received[0] == {"type": "message", "data": "test"}
208-
209-
async def test_is_connected(self, stdio_transport):
210-
"""Test checking connection status."""
211-
# Test when not connected
212-
stdio_transport._connected = False
213-
assert await stdio_transport.is_connected() is False
214-
215-
# Test when connected
216-
stdio_transport._connected = True
217-
assert await stdio_transport.is_connected() is True
218-
219-
@patch.object(asyncio, "StreamReader")
220-
@patch.object(asyncio, "get_running_loop")
221-
async def test_full_lifecycle(self, mock_get_loop, mock_stream_reader, stdio_transport):
222-
"""Test a full lifecycle of connect, send/receive, and disconnect."""
223-
# Set up mocks
224-
mock_loop = MagicMock()
225-
mock_get_loop.return_value = mock_loop
226-
227-
mock_reader = MagicMock()
228-
mock_reader_protocol = MagicMock()
229-
230-
mock_transport = MagicMock()
231-
mock_protocol = MagicMock()
232-
mock_writer = MagicMock()
233-
234-
# Mock the connect_read_pipe and connect_write_pipe methods
235-
mock_loop.connect_read_pipe = AsyncMock(return_value=(mock_reader_protocol, mock_protocol))
236-
mock_loop.connect_write_pipe = AsyncMock(return_value=(mock_transport, mock_protocol))
237-
238-
# Add stdout writer
239-
stdio_transport._stdout_writer = mock_writer
240-
241-
# Configure reader with a test message
242-
mock_message = b'{"type": "test", "content": "hello"}\n'
243-
mock_reader.readline = AsyncMock(return_value=mock_message)
244-
stdio_transport._stdin_reader = mock_reader
245-
246-
# Test connect
247-
await stdio_transport.connect()
248-
assert stdio_transport._connected is True
249-
250-
# Test is_connected
251-
assert await stdio_transport.is_connected() is True
252-
253-
# Test send_message
254-
await stdio_transport.send_message({"type": "response", "content": "world"})
255-
mock_writer.write.assert_called_once()
256-
mock_writer.drain.assert_called_once()
257-
258-
# Test receive_message
259-
async def get_first_message():
260-
async for message in stdio_transport.receive_message():
261-
return message
262-
263-
message = await get_first_message()
264-
assert message == {"type": "test", "content": "hello"}
265-
266-
# Test disconnect
267-
await stdio_transport.disconnect()
268-
assert stdio_transport._connected is False
269-
mock_writer.close.assert_called_once()
270-
mock_writer.wait_closed.assert_called_once()
76+
# ---------------------------------------------------------------------------
77+
# Tests
78+
# ---------------------------------------------------------------------------
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_send_message_happy_path(transport):
83+
"""
84+
`send_message()` should JSON-encode + newline-terminate the dict,
85+
push it to the writer, and await `drain()`.
86+
"""
87+
writer = _DummyWriter()
88+
transport._stdout_writer = writer # type: ignore[attr-defined]
89+
transport._connected = True # type: ignore[attr-defined]
90+
91+
payload = {"key": "value", "n": 123}
92+
await transport.send_message(payload)
93+
94+
assert writer.drain_called is True
95+
assert len(writer.buffer) == 1
96+
raw = writer.buffer[0]
97+
# newline-terminated and round-trips through json
98+
assert raw.endswith(b"\n")
99+
assert json.loads(raw.decode().rstrip("\n")) == payload
100+
101+
102+
@pytest.mark.asyncio
103+
async def test_send_message_not_connected_raises(transport):
104+
"""Calling `send_message()` without a writer should raise RuntimeError."""
105+
with pytest.raises(RuntimeError):
106+
await transport.send_message({"oops": 1})
107+
108+
109+
@pytest.mark.asyncio
110+
async def test_receive_message_decodes_until_eof(transport):
111+
reader = _DummyReader(['{"a":1}', '{"b":2}']) # two JSON lines then EOF
112+
transport._stdin_reader = reader # type: ignore[attr-defined]
113+
transport._connected = True # type: ignore[attr-defined]
114+
115+
messages = [m async for m in transport.receive_message()]
116+
117+
assert messages == [{"a": 1}, {"b": 2}]
118+
119+
120+
@pytest.mark.asyncio
121+
async def test_receive_message_not_connected_raises(transport):
122+
with pytest.raises(RuntimeError):
123+
async for _ in transport.receive_message(): # pragma: no cover
124+
pass
125+
126+
127+
@pytest.mark.asyncio
128+
async def test_disconnect_closes_writer_and_flags_state(transport):
129+
writer = _DummyWriter()
130+
transport._stdout_writer = writer # type: ignore[attr-defined]
131+
transport._connected = True # type: ignore[attr-defined]
132+
133+
await transport.disconnect()
134+
135+
assert writer.closed is True
136+
assert await transport.is_connected() is False
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_is_connected_reports_state(transport):
141+
transport._connected = False # type: ignore[attr-defined]
142+
assert await transport.is_connected() is False
143+
transport._connected = True # type: ignore[attr-defined]
144+
assert await transport.is_connected() is True

0 commit comments

Comments
 (0)