|
1 | 1 | # -*- coding: utf-8 -*-
|
2 |
| -"""Tests for the stdio transport implementation. |
| 2 | +""" |
| 3 | +Unit-tests for `mcpgateway.transports.stdio_transport.StdioTransport`. |
| 4 | +
|
| 5 | +The real transport interacts with the running event-loop and the |
| 6 | +process's stdin / stdout file-descriptors. Those OS objects are tricky |
| 7 | +to mock portably, so these tests **inject in-memory fakes** in place of |
| 8 | +`StreamReader` and `StreamWriter`. That lets us assert the transport's |
| 9 | +logic (JSON encoding/decoding, connection state, error handling) without |
| 10 | +ever touching real pipes. |
3 | 11 |
|
4 | 12 | Copyright 2025
|
5 | 13 | SPDX-License-Identifier: Apache-2.0
|
6 | 14 | Authors: Mihai Criveti
|
7 |
| -
|
8 |
| -This module tests the stdio transport for MCP, ensuring it properly handles |
9 |
| -communication over standard input/output streams. |
10 | 15 | """
|
11 | 16 |
|
| 17 | +from __future__ import annotations |
| 18 | + |
12 | 19 | import asyncio
|
13 |
| -from unittest.mock import AsyncMock, MagicMock, patch |
| 20 | +import json |
| 21 | +from typing import List |
14 | 22 |
|
15 | 23 | import pytest
|
16 | 24 |
|
17 | 25 | from mcpgateway.transports.stdio_transport import StdioTransport
|
18 | 26 |
|
19 | 27 |
|
| 28 | +# --------------------------------------------------------------------------- |
| 29 | +# Simple in-memory stand-ins for asyncio streams |
| 30 | +# --------------------------------------------------------------------------- |
| 31 | + |
| 32 | + |
| 33 | +class _DummyWriter: |
| 34 | + """Enough of the `StreamWriter` interface for our tests.""" |
| 35 | + |
| 36 | + def __init__(self): |
| 37 | + self.buffer: List[bytes] = [] |
| 38 | + self.drain_called = False |
| 39 | + self.closed = False |
| 40 | + |
| 41 | + def write(self, data: bytes) -> None: # noqa: D401 |
| 42 | + self.buffer.append(data) |
| 43 | + |
| 44 | + async def drain(self) -> None: # noqa: D401 |
| 45 | + self.drain_called = True |
| 46 | + |
| 47 | + def close(self) -> None: # noqa: D401 |
| 48 | + self.closed = True |
| 49 | + |
| 50 | + async def wait_closed(self) -> None: # noqa: D401 |
| 51 | + return |
| 52 | + |
| 53 | + |
| 54 | +class _DummyReader: |
| 55 | + """Yield bytes from an internal list each time `.readline()` is awaited.""" |
| 56 | + |
| 57 | + def __init__(self, lines: list[str]): |
| 58 | + self._lines = [l.encode() + b"\n" for l in lines] |
| 59 | + |
| 60 | + async def readline(self) -> bytes: # noqa: D401 |
| 61 | + await asyncio.sleep(0) # let the event-loop breathe |
| 62 | + return self._lines.pop(0) if self._lines else b"" |
| 63 | + |
| 64 | + |
| 65 | +# --------------------------------------------------------------------------- |
| 66 | +# Fixtures |
| 67 | +# --------------------------------------------------------------------------- |
| 68 | + |
| 69 | + |
20 | 70 | @pytest.fixture
|
21 |
| -def stdio_transport(): |
22 |
| - """Create a StdioTransport instance for testing.""" |
| 71 | +def transport(): |
| 72 | + """Provide an *unconnected* StdioTransport for each test.""" |
23 | 73 | return StdioTransport()
|
24 | 74 |
|
25 | 75 |
|
26 |
| -class TestStdioTransport: |
27 |
| - """Test suite for the StdioTransport class.""" |
28 |
| - |
29 |
| - @patch("asyncio.get_running_loop") |
30 |
| - async def test_connect(self, mock_get_loop, stdio_transport): |
31 |
| - """Test establishing a connection.""" |
32 |
| - # Set up mocks |
33 |
| - mock_loop = MagicMock() |
34 |
| - mock_get_loop.return_value = mock_loop |
35 |
| - |
36 |
| - MagicMock() |
37 |
| - mock_reader_protocol = MagicMock() |
38 |
| - |
39 |
| - mock_transport = MagicMock() |
40 |
| - mock_protocol = MagicMock() |
41 |
| - |
42 |
| - # Mock the connect_read_pipe and connect_write_pipe methods |
43 |
| - mock_loop.connect_read_pipe = AsyncMock(return_value=(mock_transport, mock_reader_protocol)) |
44 |
| - mock_loop.connect_write_pipe = AsyncMock(return_value=(mock_transport, mock_protocol)) |
45 |
| - |
46 |
| - # Call the method under test |
47 |
| - await stdio_transport.connect() |
48 |
| - |
49 |
| - # Verify the expected calls |
50 |
| - mock_get_loop.assert_called_once() |
51 |
| - mock_loop.connect_read_pipe.assert_called_once() |
52 |
| - mock_loop.connect_write_pipe.assert_called_once() |
53 |
| - |
54 |
| - # Verify the connection state |
55 |
| - assert stdio_transport._connected is True |
56 |
| - |
57 |
| - @patch("asyncio.StreamWriter") |
58 |
| - async def test_disconnect(self, mock_writer, stdio_transport): |
59 |
| - """Test closing a connection.""" |
60 |
| - # Set up mock |
61 |
| - stdio_transport._stdout_writer = mock_writer |
62 |
| - stdio_transport._connected = True |
63 |
| - |
64 |
| - # Call the method under test |
65 |
| - await stdio_transport.disconnect() |
66 |
| - |
67 |
| - # Verify the expected calls |
68 |
| - mock_writer.close.assert_called_once() |
69 |
| - mock_writer.wait_closed.assert_called_once() |
70 |
| - |
71 |
| - # Verify the connection state |
72 |
| - assert stdio_transport._connected is False |
73 |
| - |
74 |
| - async def test_disconnect_not_connected(self, stdio_transport): |
75 |
| - """Test disconnecting when not connected.""" |
76 |
| - # Ensure not connected |
77 |
| - stdio_transport._stdout_writer = None |
78 |
| - stdio_transport._connected = False |
79 |
| - |
80 |
| - # Call the method under test |
81 |
| - await stdio_transport.disconnect() |
82 |
| - |
83 |
| - # Verify the connection state |
84 |
| - assert stdio_transport._connected is False |
85 |
| - |
86 |
| - async def test_send_message_not_connected(self, stdio_transport): |
87 |
| - """Test sending a message when not connected.""" |
88 |
| - # Ensure not connected |
89 |
| - stdio_transport._stdout_writer = None |
90 |
| - stdio_transport._connected = False |
91 |
| - |
92 |
| - # Verify the expected exception |
93 |
| - with pytest.raises(RuntimeError, match="Transport not connected"): |
94 |
| - await stdio_transport.send_message({"type": "test"}) |
95 |
| - |
96 |
| - @patch("asyncio.StreamWriter") |
97 |
| - async def test_send_message(self, mock_writer, stdio_transport): |
98 |
| - """Test sending a message.""" |
99 |
| - # Set up mock |
100 |
| - stdio_transport._stdout_writer = mock_writer |
101 |
| - stdio_transport._connected = True |
102 |
| - |
103 |
| - # Call the method under test |
104 |
| - await stdio_transport.send_message({"type": "test", "data": "message"}) |
105 |
| - |
106 |
| - # Verify the expected calls |
107 |
| - mock_writer.write.assert_called_once() |
108 |
| - mock_writer.drain.assert_called_once() |
109 |
| - |
110 |
| - # Verify the encoded message |
111 |
| - call_args = mock_writer.write.call_args[0][0] |
112 |
| - assert b'{"type": "test", "data": "message"}\n' == call_args |
113 |
| - |
114 |
| - @patch("asyncio.StreamWriter") |
115 |
| - async def test_send_message_exception(self, mock_writer, stdio_transport): |
116 |
| - """Test sending a message when an error occurs.""" |
117 |
| - # Set up mock with an exception |
118 |
| - stdio_transport._stdout_writer = mock_writer |
119 |
| - stdio_transport._connected = True |
120 |
| - mock_writer.write.side_effect = Exception("Write error") |
121 |
| - |
122 |
| - # Verify the expected exception |
123 |
| - with pytest.raises(Exception, match="Write error"): |
124 |
| - await stdio_transport.send_message({"type": "test"}) |
125 |
| - |
126 |
| - async def test_receive_message_not_connected(self, stdio_transport): |
127 |
| - """Test receiving messages when not connected.""" |
128 |
| - # Ensure not connected |
129 |
| - stdio_transport._stdin_reader = None |
130 |
| - stdio_transport._connected = False |
131 |
| - |
132 |
| - # Verify the expected exception |
133 |
| - with pytest.raises(RuntimeError, match="Transport not connected"): |
134 |
| - async for _ in stdio_transport.receive_message(): |
135 |
| - pass |
136 |
| - |
137 |
| - @patch("asyncio.StreamReader") |
138 |
| - async def test_receive_message(self, mock_reader, stdio_transport): |
139 |
| - """Test receiving messages.""" |
140 |
| - # Set up mock with two messages followed by EOF |
141 |
| - stdio_transport._stdin_reader = mock_reader |
142 |
| - stdio_transport._connected = True |
143 |
| - |
144 |
| - message1 = b'{"type": "message1", "data": "test1"}\n' |
145 |
| - message2 = b'{"type": "message2", "data": "test2"}\n' |
146 |
| - |
147 |
| - # Configure the mock to return messages and then an empty response (EOF) |
148 |
| - mock_reader.readline = AsyncMock(side_effect=[message1, message2, b""]) |
149 |
| - |
150 |
| - # Collect received messages |
151 |
| - received = [] |
152 |
| - async for message in stdio_transport.receive_message(): |
153 |
| - received.append(message) |
154 |
| - |
155 |
| - # Verify the expected calls and received messages |
156 |
| - assert mock_reader.readline.call_count == 3 |
157 |
| - assert len(received) == 2 |
158 |
| - assert received[0] == {"type": "message1", "data": "test1"} |
159 |
| - assert received[1] == {"type": "message2", "data": "test2"} |
160 |
| - |
161 |
| - @patch("asyncio.StreamReader") |
162 |
| - async def test_receive_message_exception(self, mock_reader, stdio_transport): |
163 |
| - """Test receiving messages when a non-fatal error occurs.""" |
164 |
| - # Set up mock with a valid message, then an invalid one, then EOF |
165 |
| - stdio_transport._stdin_reader = mock_reader |
166 |
| - stdio_transport._connected = True |
167 |
| - |
168 |
| - message1 = b'{"type": "message1", "data": "test1"}\n' |
169 |
| - invalid_message = b"not valid json\n" |
170 |
| - |
171 |
| - # Configure the mock to return a valid message, then an invalid one, then EOF |
172 |
| - mock_reader.readline = AsyncMock(side_effect=[message1, invalid_message, b""]) |
173 |
| - |
174 |
| - # Collect received messages (should only get the valid one) |
175 |
| - received = [] |
176 |
| - async for message in stdio_transport.receive_message(): |
177 |
| - received.append(message) |
178 |
| - |
179 |
| - # Verify the expected calls and received messages |
180 |
| - assert mock_reader.readline.call_count == 3 |
181 |
| - assert len(received) == 1 |
182 |
| - assert received[0] == {"type": "message1", "data": "test1"} |
183 |
| - |
184 |
| - @patch("asyncio.StreamReader") |
185 |
| - async def test_receive_message_cancellation(self, mock_reader, stdio_transport): |
186 |
| - """Test receiving messages with cancellation.""" |
187 |
| - # Set up mock with a message |
188 |
| - stdio_transport._stdin_reader = mock_reader |
189 |
| - stdio_transport._connected = True |
190 |
| - |
191 |
| - message = b'{"type": "message", "data": "test"}\n' |
192 |
| - |
193 |
| - # Configure the mock to return a message and then raise a cancellation |
194 |
| - mock_reader.readline = AsyncMock(side_effect=[message, asyncio.CancelledError()]) |
195 |
| - |
196 |
| - # Collect received messages until cancellation |
197 |
| - received = [] |
198 |
| - try: |
199 |
| - async for message in stdio_transport.receive_message(): |
200 |
| - received.append(message) |
201 |
| - except asyncio.CancelledError: |
202 |
| - pass # Expected |
203 |
| - |
204 |
| - # Verify the expected calls and received messages |
205 |
| - assert mock_reader.readline.call_count == 2 |
206 |
| - assert len(received) == 1 |
207 |
| - assert received[0] == {"type": "message", "data": "test"} |
208 |
| - |
209 |
| - async def test_is_connected(self, stdio_transport): |
210 |
| - """Test checking connection status.""" |
211 |
| - # Test when not connected |
212 |
| - stdio_transport._connected = False |
213 |
| - assert await stdio_transport.is_connected() is False |
214 |
| - |
215 |
| - # Test when connected |
216 |
| - stdio_transport._connected = True |
217 |
| - assert await stdio_transport.is_connected() is True |
218 |
| - |
219 |
| - @patch.object(asyncio, "StreamReader") |
220 |
| - @patch.object(asyncio, "get_running_loop") |
221 |
| - async def test_full_lifecycle(self, mock_get_loop, mock_stream_reader, stdio_transport): |
222 |
| - """Test a full lifecycle of connect, send/receive, and disconnect.""" |
223 |
| - # Set up mocks |
224 |
| - mock_loop = MagicMock() |
225 |
| - mock_get_loop.return_value = mock_loop |
226 |
| - |
227 |
| - mock_reader = MagicMock() |
228 |
| - mock_reader_protocol = MagicMock() |
229 |
| - |
230 |
| - mock_transport = MagicMock() |
231 |
| - mock_protocol = MagicMock() |
232 |
| - mock_writer = MagicMock() |
233 |
| - |
234 |
| - # Mock the connect_read_pipe and connect_write_pipe methods |
235 |
| - mock_loop.connect_read_pipe = AsyncMock(return_value=(mock_reader_protocol, mock_protocol)) |
236 |
| - mock_loop.connect_write_pipe = AsyncMock(return_value=(mock_transport, mock_protocol)) |
237 |
| - |
238 |
| - # Add stdout writer |
239 |
| - stdio_transport._stdout_writer = mock_writer |
240 |
| - |
241 |
| - # Configure reader with a test message |
242 |
| - mock_message = b'{"type": "test", "content": "hello"}\n' |
243 |
| - mock_reader.readline = AsyncMock(return_value=mock_message) |
244 |
| - stdio_transport._stdin_reader = mock_reader |
245 |
| - |
246 |
| - # Test connect |
247 |
| - await stdio_transport.connect() |
248 |
| - assert stdio_transport._connected is True |
249 |
| - |
250 |
| - # Test is_connected |
251 |
| - assert await stdio_transport.is_connected() is True |
252 |
| - |
253 |
| - # Test send_message |
254 |
| - await stdio_transport.send_message({"type": "response", "content": "world"}) |
255 |
| - mock_writer.write.assert_called_once() |
256 |
| - mock_writer.drain.assert_called_once() |
257 |
| - |
258 |
| - # Test receive_message |
259 |
| - async def get_first_message(): |
260 |
| - async for message in stdio_transport.receive_message(): |
261 |
| - return message |
262 |
| - |
263 |
| - message = await get_first_message() |
264 |
| - assert message == {"type": "test", "content": "hello"} |
265 |
| - |
266 |
| - # Test disconnect |
267 |
| - await stdio_transport.disconnect() |
268 |
| - assert stdio_transport._connected is False |
269 |
| - mock_writer.close.assert_called_once() |
270 |
| - mock_writer.wait_closed.assert_called_once() |
| 76 | +# --------------------------------------------------------------------------- |
| 77 | +# Tests |
| 78 | +# --------------------------------------------------------------------------- |
| 79 | + |
| 80 | + |
| 81 | +@pytest.mark.asyncio |
| 82 | +async def test_send_message_happy_path(transport): |
| 83 | + """ |
| 84 | + `send_message()` should JSON-encode + newline-terminate the dict, |
| 85 | + push it to the writer, and await `drain()`. |
| 86 | + """ |
| 87 | + writer = _DummyWriter() |
| 88 | + transport._stdout_writer = writer # type: ignore[attr-defined] |
| 89 | + transport._connected = True # type: ignore[attr-defined] |
| 90 | + |
| 91 | + payload = {"key": "value", "n": 123} |
| 92 | + await transport.send_message(payload) |
| 93 | + |
| 94 | + assert writer.drain_called is True |
| 95 | + assert len(writer.buffer) == 1 |
| 96 | + raw = writer.buffer[0] |
| 97 | + # newline-terminated and round-trips through json |
| 98 | + assert raw.endswith(b"\n") |
| 99 | + assert json.loads(raw.decode().rstrip("\n")) == payload |
| 100 | + |
| 101 | + |
| 102 | +@pytest.mark.asyncio |
| 103 | +async def test_send_message_not_connected_raises(transport): |
| 104 | + """Calling `send_message()` without a writer should raise RuntimeError.""" |
| 105 | + with pytest.raises(RuntimeError): |
| 106 | + await transport.send_message({"oops": 1}) |
| 107 | + |
| 108 | + |
| 109 | +@pytest.mark.asyncio |
| 110 | +async def test_receive_message_decodes_until_eof(transport): |
| 111 | + reader = _DummyReader(['{"a":1}', '{"b":2}']) # two JSON lines then EOF |
| 112 | + transport._stdin_reader = reader # type: ignore[attr-defined] |
| 113 | + transport._connected = True # type: ignore[attr-defined] |
| 114 | + |
| 115 | + messages = [m async for m in transport.receive_message()] |
| 116 | + |
| 117 | + assert messages == [{"a": 1}, {"b": 2}] |
| 118 | + |
| 119 | + |
| 120 | +@pytest.mark.asyncio |
| 121 | +async def test_receive_message_not_connected_raises(transport): |
| 122 | + with pytest.raises(RuntimeError): |
| 123 | + async for _ in transport.receive_message(): # pragma: no cover |
| 124 | + pass |
| 125 | + |
| 126 | + |
| 127 | +@pytest.mark.asyncio |
| 128 | +async def test_disconnect_closes_writer_and_flags_state(transport): |
| 129 | + writer = _DummyWriter() |
| 130 | + transport._stdout_writer = writer # type: ignore[attr-defined] |
| 131 | + transport._connected = True # type: ignore[attr-defined] |
| 132 | + |
| 133 | + await transport.disconnect() |
| 134 | + |
| 135 | + assert writer.closed is True |
| 136 | + assert await transport.is_connected() is False |
| 137 | + |
| 138 | + |
| 139 | +@pytest.mark.asyncio |
| 140 | +async def test_is_connected_reports_state(transport): |
| 141 | + transport._connected = False # type: ignore[attr-defined] |
| 142 | + assert await transport.is_connected() is False |
| 143 | + transport._connected = True # type: ignore[attr-defined] |
| 144 | + assert await transport.is_connected() is True |
0 commit comments